Skip to content

Commit 8f8d638

Browse files
Add all paper data and incorporate into STEP pytest checks
1 parent e39791a commit 8f8d638

File tree

6 files changed

+152
-51
lines changed

6 files changed

+152
-51
lines changed

sequentialized_barnard_tests/step.py

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -166,21 +166,18 @@ def step(
166166
x = int(self._state[0])
167167
y = int(self._state[1])
168168

169-
if (y > x and self.alternative == Hypothesis.P0LessThanP1) or (
170-
x > y and self.alternative == Hypothesis.P0MoreThanP1
171-
):
172-
if y > x:
173-
x_absolute = x
174-
y_absolute = y
175-
else:
176-
x_absolute = y
177-
y_absolute = x
169+
if y > x:
170+
# and self.alternative == Hypothesis.P0LessThanP1) or (
171+
# x > y and self.alternative == Hypothesis.P0MoreThanP1
172+
# ):
173+
x_absolute = x
174+
y_absolute = y
178175

179176
# New policy > old policy (empirically)
180177
# Therefore, look only to REJECT in standard setting
181178

182179
# Extract relevant component of policy
183-
decision_array = self.policy[self._t - 1][x_absolute]
180+
decision_array = self.policy[self._t][x_absolute]
184181

185182
# Number of non-zero / non-unity policy bins at this x and t
186183
L = decision_array.shape[0] - 1
@@ -189,16 +186,19 @@ def step(
189186
critical_zero_y = int(decision_array[0])
190187

191188
if y_absolute <= critical_zero_y: # Current state cannot be significant
192-
info = {"Time": self._t, "State": self._state}
189+
info = {"Time": self._t + 1, "State": self._state}
193190
result = TestResult(self._current_decision, info)
194191

195192
return result
196193

197194
elif (
198195
y_absolute > critical_zero_y + L
199196
): # Current state is definitely significant
200-
self._current_decision = Decision.AcceptAlternative
201-
info = {"Time": self._t, "State": self._state}
197+
if self.alternative == Hypothesis.P0LessThanP1:
198+
self._current_decision = Decision.AcceptAlternative
199+
else:
200+
self._current_decision = Decision.FailToDecide
201+
info = {"Time": self._t + 1, "State": self._state}
202202
result = TestResult(self._current_decision, info)
203203

204204
return result
@@ -212,7 +212,65 @@ def step(
212212
if (
213213
random_scalar <= comparator_rv
214214
): # Then we have probabilistically rejected
215+
if self.alternative == Hypothesis.P0LessThanP1:
216+
self._current_decision = Decision.AcceptAlternative
217+
else:
218+
self._current_decision = Decision.FailToDecide
219+
info = {"Time": self._t + 1, "State": self._state}
220+
result = TestResult(self._current_decision, info)
221+
else: # Then we have probabilistically continued
222+
info = {"Time": self._t + 1, "State": self._state}
223+
result = TestResult(self._current_decision, info)
224+
225+
return result
226+
227+
elif x > y:
228+
x_absolute = y
229+
y_absolute = x
230+
231+
# New policy > old policy (empirically)
232+
# Therefore, look only to REJECT in reverse setting
233+
234+
# Extract relevant component of policy
235+
decision_array = self.policy[self._t][x_absolute]
236+
237+
# Number of non-zero / non-unity policy bins at this x and t
238+
L = decision_array.shape[0] - 1
239+
240+
# Highest value of y for which we CONTINUE [i.e., policy = 0]
241+
critical_zero_y = int(decision_array[0])
242+
243+
if y_absolute <= critical_zero_y: # Current state cannot be significant
244+
info = {"Time": self._t + 1, "State": self._state}
245+
result = TestResult(self._current_decision, info)
246+
247+
return result
248+
249+
elif (
250+
y_absolute > critical_zero_y + L
251+
): # Current state is definitely significant
252+
if self.alternative == Hypothesis.P0MoreThanP1:
215253
self._current_decision = Decision.AcceptAlternative
254+
else:
255+
self._current_decision = Decision.FailToDecide
256+
info = {"Time": self._t + 1, "State": self._state}
257+
result = TestResult(self._current_decision, info)
258+
259+
return result
260+
261+
else: # Current state is in probabilistic regime
262+
# random_scalar = np.random.rand(
263+
# 1
264+
# ) # TODO: add some kind of seeding procedure to ensure repeatability
265+
random_scalar = self.rng.random(1)
266+
comparator_rv = decision_array[y_absolute - critical_zero_y]
267+
if (
268+
random_scalar <= comparator_rv
269+
): # Then we have probabilistically rejected
270+
if self.alternative == Hypothesis.P0MoreThanP1:
271+
self._current_decision = Decision.AcceptAlternative
272+
else:
273+
self._current_decision = Decision.FailToDecide
216274
info = {"Time": self._t, "State": self._state}
217275
result = TestResult(self._current_decision, info)
218276
else: # Then we have probabilistically continued
@@ -221,8 +279,8 @@ def step(
221279

222280
return result
223281
else:
224-
# Cannot reject; as test is one-sided, can only continue!
225-
info = {"Time": self._t, "State": self._state}
282+
# Cannot reject because delta is exactly 0; can only continue!
283+
info = {"Time": self._t + 1, "State": self._state}
226284
result = TestResult(self._current_decision, info)
227285

228286
return result
@@ -238,7 +296,7 @@ def reset(
238296
verbose (bool, optional): If True, print the outputs to stdout.
239297
Defaults to False.
240298
"""
241-
self._state = np.zeros(2)
299+
self._state = np.zeros(2).astype(int)
242300
self._t = int(0)
243301
self._current_decision = Decision.FailToDecide
244302

@@ -377,19 +435,20 @@ def step(
377435
)
378436
)
379437

380-
# Iterate time state
381-
self._t += 1
382-
383438
# Handle case in which we have exceeded n_max
384439
if self._t > self.n_max:
385440
warnings.warn(
386441
"Have exceeded the allowed number of evals; not updating internal states."
387442
)
443+
self._t += 1
388444
info = {"Time": self._t, "State": self._state}
389445
result = TestResult(self._current_decision, info)
390446

391447
return result
392448

449+
# Iterate time state
450+
self._t += 1
451+
393452
if self.policy is None:
394453
# warnings.warn(
395454
# "No policy assigned, so will default to Fail to Decide. Ensure "
@@ -431,7 +490,7 @@ def step(
431490
critical_zero_y = int(decision_array[0])
432491

433492
if y_absolute <= critical_zero_y: # Current state cannot be significant
434-
info = {"Time": self._t, "State": self._state}
493+
info = {"Time": self._t + 1, "State": self._state}
435494
result = TestResult(self._current_decision, info)
436495

437496
return result
@@ -443,7 +502,7 @@ def step(
443502
self._current_decision = Decision.AcceptAlternative
444503
else:
445504
self._current_decision = Decision.AcceptNull
446-
info = {"Time": self._t, "State": self._state}
505+
info = {"Time": self._t + 1, "State": self._state}
447506
result = TestResult(self._current_decision, info)
448507

449508
return result
@@ -461,10 +520,10 @@ def step(
461520
self._current_decision = Decision.AcceptAlternative
462521
else:
463522
self._current_decision = Decision.AcceptNull
464-
info = {"Time": self._t, "State": self._state}
523+
info = {"Time": self._t + 1, "State": self._state}
465524
result = TestResult(self._current_decision, info)
466525
else: # Then we have probabilistically continued
467-
info = {"Time": self._t, "State": self._state}
526+
info = {"Time": self._t + 1, "State": self._state}
468527
result = TestResult(self._current_decision, info)
469528

470529
return result
@@ -486,7 +545,7 @@ def step(
486545
critical_zero_y = int(decision_array[0])
487546

488547
if y_absolute <= critical_zero_y: # Current state cannot be significant
489-
info = {"Time": self._t, "State": self._state}
548+
info = {"Time": self._t + 1, "State": self._state}
490549
result = TestResult(self._current_decision, info)
491550

492551
return result
@@ -498,7 +557,7 @@ def step(
498557
self._current_decision = Decision.AcceptAlternative
499558
else:
500559
self._current_decision = Decision.AcceptNull
501-
info = {"Time": self._t, "State": self._state}
560+
info = {"Time": self._t + 1, "State": self._state}
502561
result = TestResult(self._current_decision, info)
503562

504563
return result
@@ -525,7 +584,7 @@ def step(
525584
return result
526585
else:
527586
# Cannot reject because delta is exactly 0; can only continue!
528-
info = {"Time": self._t, "State": self._state}
587+
info = {"Time": self._t + 1, "State": self._state}
529588
result = TestResult(self._current_decision, info)
530589

531590
return result
928 Bytes
Binary file not shown.
7.94 KB
Binary file not shown.
7.94 KB
Binary file not shown.
7.94 KB
Binary file not shown.

tests/sequentialized_barnard_tests/test_step.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,21 @@
1818
)
1919
).resolve()
2020
)
21-
eval_trajectories = np.load(f"{paper_data_path}/TRI_CLEAN_SPILL_v2.npy")
21+
eval_clean_up_spill = np.load(
22+
f"{paper_data_path}/TRI_CLEAN_SPILL_v2.npy"
23+
) # Must be flipped for standard form
24+
eval_fold_red_towel = np.load(
25+
f"{paper_data_path}/TRI_FOLD_RED_TOWEL.npy"
26+
) # ALREADY in standard form
27+
eval_sim_spoon_on_towel = np.load(
28+
f"{paper_data_path}/TRI_SIM_SPOON_ON_TOWEL.npy"
29+
) # Must be flipped for standard form
30+
eval_sim_eggplant_in_basket = np.load(
31+
f"{paper_data_path}/TRI_SIM_EGGPLANT_IN_BASKET.npy"
32+
) # Must be flipped for standard form
33+
eval_sim_stack_cube = np.load(
34+
f"{paper_data_path}/TRI_SIM_STACK_CUBE.npy"
35+
) # Must be flipped for standard form
2236

2337

2438
@pytest.fixture(scope="module")
@@ -63,10 +77,10 @@ def test_step_input_value_error(step):
6377
(Hypothesis.P0MoreThanP1, np.zeros(15), np.ones(15), Decision.FailToDecide),
6478
(Hypothesis.P0LessThanP1, np.ones(15), np.zeros(15), Decision.FailToDecide),
6579
(Hypothesis.P0MoreThanP1, np.ones(15), np.zeros(15), Decision.AcceptAlternative),
66-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], Decision.AcceptAlternative),
67-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], Decision.FailToDecide),
68-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], Decision.FailToDecide),
69-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], Decision.AcceptAlternative),
80+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], Decision.AcceptAlternative),
81+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], Decision.FailToDecide),
82+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], Decision.FailToDecide),
83+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], Decision.AcceptAlternative),
7084
# fmt: on
7185
],
7286
indirect=["step"],
@@ -80,17 +94,21 @@ def test_step(step, sequence_0, sequence_1, expected):
8094
("step", "sequence_0", "sequence_1", "expected"),
8195
[
8296
# fmt: off
83-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 22.5),
84-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 50),
85-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 50),
86-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 22.5),
97+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 23),
98+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 50),
99+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 50),
100+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 23),
101+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 21.5),
102+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 50),
103+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 50),
104+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 21.5),
87105
# fmt: on
88106
],
89107
indirect=["step"],
90108
)
91109
def test_step_time(step, sequence_0, sequence_1, expected):
92110
result = step.run_on_sequence(sequence_0, sequence_1)
93-
assert np.abs(float(result.info["Time"]) - expected) <= 3.0
111+
assert np.abs(float(result.info["Time"]) - expected) <= 1.2
94112

95113

96114
@pytest.fixture(scope="module")
@@ -108,17 +126,21 @@ def step500(request):
108126
("step500", "sequence_0", "sequence_1", "expected"),
109127
[
110128
# fmt: off
111-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 33),
112-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 50),
113-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 50),
114-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 33),
129+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 25.5),
130+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 50),
131+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 50),
132+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 25.5),
133+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 23.5),
134+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 50),
135+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 50),
136+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 23.5),
115137
# fmt: on
116138
],
117139
indirect=["step500"],
118140
)
119141
def test_step500_time(step500, sequence_0, sequence_1, expected):
120142
result = step500.run_on_sequence(sequence_0, sequence_1)
121-
assert np.abs(result.info["Time"] - expected) <= 1.5
143+
assert np.abs(result.info["Time"] - expected) <= 0.6
122144

123145

124146
##### Mirrored STEP Test #####
@@ -160,17 +182,21 @@ def test_mirrored_step(mirrored_step, sequence_0, sequence_1, expected):
160182
("mirrored_step", "sequence_0", "sequence_1", "expected"),
161183
[
162184
# fmt: off
163-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 25),
164-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 25),
165-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 25),
166-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 25),
185+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 23.5),
186+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 23.5),
187+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 23.5),
188+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 23.5),
189+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 21.5),
190+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 21.5),
191+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 21.5),
192+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 21.5),
167193
# fmt: on
168194
],
169195
indirect=["mirrored_step"],
170196
)
171197
def test_mirrored_step_time(mirrored_step, sequence_0, sequence_1, expected):
172198
result = mirrored_step.run_on_sequence(sequence_0, sequence_1)
173-
assert np.abs(result.info["Time"] - expected) <= 1.5
199+
assert np.abs(result.info["Time"] - expected) <= 0.6
174200

175201

176202
@pytest.fixture(scope="module")
@@ -188,14 +214,30 @@ def mirrored_step500(request):
188214
("mirrored_step500", "sequence_0", "sequence_1", "expected"),
189215
[
190216
# fmt: off
191-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 33),
192-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 1], eval_trajectories[:, 0], 33),
193-
(Hypothesis.P0LessThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 33),
194-
(Hypothesis.P0MoreThanP1, eval_trajectories[:, 0], eval_trajectories[:, 1], 33),
217+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 25.5),
218+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 1], eval_clean_up_spill[:, 0], 25.5),
219+
(Hypothesis.P0LessThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 25.5),
220+
(Hypothesis.P0MoreThanP1, eval_clean_up_spill[:, 0], eval_clean_up_spill[:, 1], 25.5),
221+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 23.5),
222+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 0], eval_fold_red_towel[:, 1], 23.5),
223+
(Hypothesis.P0LessThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 23.5),
224+
(Hypothesis.P0MoreThanP1, eval_fold_red_towel[:, 1], eval_fold_red_towel[:, 0], 23.5),
225+
(Hypothesis.P0LessThanP1, eval_sim_spoon_on_towel[:, 1], eval_sim_spoon_on_towel[:, 0], 32.5),
226+
(Hypothesis.P0MoreThanP1, eval_sim_spoon_on_towel[:, 1], eval_sim_spoon_on_towel[:, 0], 32.5),
227+
(Hypothesis.P0LessThanP1, eval_sim_spoon_on_towel[:, 0], eval_sim_spoon_on_towel[:, 1], 32.5),
228+
(Hypothesis.P0MoreThanP1, eval_sim_spoon_on_towel[:, 0], eval_sim_spoon_on_towel[:, 1], 32.5),
229+
(Hypothesis.P0LessThanP1, eval_sim_eggplant_in_basket[:, 1], eval_sim_eggplant_in_basket[:, 0], 119.5),
230+
(Hypothesis.P0MoreThanP1, eval_sim_eggplant_in_basket[:, 1], eval_sim_eggplant_in_basket[:, 0], 119.5),
231+
(Hypothesis.P0LessThanP1, eval_sim_eggplant_in_basket[:, 0], eval_sim_eggplant_in_basket[:, 1], 119.5),
232+
(Hypothesis.P0MoreThanP1, eval_sim_eggplant_in_basket[:, 0], eval_sim_eggplant_in_basket[:, 1], 119.5),
233+
(Hypothesis.P0LessThanP1, eval_sim_stack_cube[:, 1], eval_sim_stack_cube[:, 0], 172.5),
234+
(Hypothesis.P0MoreThanP1, eval_sim_stack_cube[:, 1], eval_sim_stack_cube[:, 0], 172.5),
235+
(Hypothesis.P0LessThanP1, eval_sim_stack_cube[:, 0], eval_sim_stack_cube[:, 1], 172.5),
236+
(Hypothesis.P0MoreThanP1, eval_sim_stack_cube[:, 0], eval_sim_stack_cube[:, 1], 172.5),
195237
# fmt: on
196238
],
197239
indirect=["mirrored_step500"],
198240
)
199241
def test_mirrored_step500_time(mirrored_step500, sequence_0, sequence_1, expected):
200242
result = mirrored_step500.run_on_sequence(sequence_0, sequence_1)
201-
assert np.abs(result.info["Time"] - expected) <= 1.5
243+
assert np.abs(result.info["Time"] - expected) <= 0.6

0 commit comments

Comments
 (0)