Skip to content

Commit e39791a

Browse files
Incorporate Haruki comments; remove some extraneous commented sections; DO NOT include automatic call to synthesize_policy() pending clarification of __init__ issue.
1 parent 7fe5a68 commit e39791a

File tree

5 files changed

+13
-24
lines changed

5 files changed

+13
-24
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .base import Decision, Hypothesis, TestResult
2-
from .lai import LaiTest
3-
from .savi import SaviTest
4-
from .step import StepTest
2+
from .lai import LaiTest, MirroredLaiTest
3+
from .savi import MirroredSaviTest, SaviTest
4+
from .step import MirroredStepTest, StepTest

sequentialized_barnard_tests/step.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pickle
1313
import warnings
1414
from pathlib import Path
15-
from typing import Union
15+
from typing import Optional, Union
1616

1717
import numpy as np
1818

@@ -40,9 +40,6 @@ class StepTest(SequentialTestBase):
4040
use_p_norm (bool): whether to use p_norm shape (True) or partial sums of the zeta function (False).
4141
policy (List[ArrayLike]): the evaluation decision-making algorithm. Length n_max, each element is an associated array.
4242
need_new_policy (bool): indicator that a policy has not been previously synthesized for these test parameters.
43-
_state (ArrayLike): internal state for a particular test. Set to np.zeros(2) when the test is reset.
44-
_t (int): internal time state for a particular test. Set to 0 when the test is reset.
45-
_current_decision (Decision): internal decision state for a particular test. Set to FailToDecide when test is reset.
4643
"""
4744

4845
def __init__(
@@ -52,7 +49,7 @@ def __init__(
5249
alpha: float,
5350
shape_parameter: float = 0.0,
5451
use_p_norm: bool = False,
55-
random_seed: int = None,
52+
random_seed: Optional[int] = None,
5653
verbose: bool = False,
5754
) -> None:
5855
"""Initializes the test object.
@@ -289,9 +286,10 @@ def load_existing_policy(
289286
self.policy = pickle.load(filename)
290287
self.need_new_policy = False
291288
except:
292-
warnings.warn(f"Current policy path: {policy_path}")
293-
# "Unable to find policy with the assigned test parameters. An additional policy synthesis procedure may be required."
294-
# f"Current policy path: {policy_path}"
289+
warnings.warn(
290+
f"Current policy path: {policy_path}"
291+
"Unable to find policy with the assigned test parameters. An additional policy synthesis procedure may be required."
292+
)
295293

296294
self.policy_path = policy_path
297295

@@ -315,9 +313,6 @@ class MirroredStepTest(StepTest):
315313
use_p_norm (bool): whether to use p_norm shape (True) or partial sums of the zeta function (False).
316314
policy (List[ArrayLike]): the evaluation decision-making algorithm. Length n_max, each element is an associated array.
317315
need_new_policy (bool): indicator that a policy has not been previously synthesized for these test parameters.
318-
_state (ArrayLike): internal state for a particular test. Set to np.zeros(2) when the test is reset.
319-
_t (int): internal time state for a particular test. Set to 0 when the test is reset.
320-
_current_decision (Decision): internal decision state for a particular test. Set to FailToDecide when test is reset.
321316
"""
322317

323318
def __init__(
@@ -534,12 +529,3 @@ def step(
534529
result = TestResult(self._current_decision, info)
535530

536531
return result
537-
538-
539-
# if __name__ == "__main__":
540-
# step = StepTest(Hypothesis.P0LessThanP1, 200, 0.05)
541-
# print(step.policy_path)
542-
# print()
543-
# step = StepTest(Hypothesis.P0LessThanP1, 500, 0.05)
544-
# print(step.policy_path)
545-
# print()

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
package_data={
1212
"sequentialized_barnard_tests": [
1313
"data/lai_calibration_data.npy",
14+
"policies/n_max_100_alpha_0.05_shape_parameter_0.0_pnorm_False/policy_compressed.pkl",
15+
"policies/n_max_200_alpha_0.05_shape_parameter_0.0_pnorm_False/policy_compressed.pkl",
16+
"policies/n_max_500_alpha_0.05_shape_parameter_0.0_pnorm_False/policy_compressed.pkl",
1417
],
1518
},
1619
install_requires=[

tests/sequentialized_barnard_tests/test_step.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
Path(
1515
os.path.join(
1616
os.path.dirname(os.path.abspath(__file__)),
17-
"../../sequentialized_barnard_tests/eval_data/",
17+
"../eval_data/",
1818
)
1919
).resolve()
2020
)

0 commit comments

Comments
 (0)