Commit 6db47ef

Merge pull request #32 from TRI-ML/auto_switch
Automatic switch between STEP and Lai
2 parents da05a57 + f095138 commit 6db47ef

6 files changed

+190
-34
lines changed

README.md

Lines changed: 88 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,7 @@ Features in development:
2929
## Installation Instructions \[Standard\]
3030
The basic environment setup is shown below. A virtual / conda environment may be used; however, the requirements are quite lightweight, so this is probably not needed.
3131
```bash
32-
$ cd <some_directory>
33-
$ git clone [email protected]:TRI-ML/sequentialized_barnard_tests.git
34-
$ cd sequentialized_barnard_tests
35-
$ pip install -r requirements.txt
36-
$ pip install -e .
32+
$ pip install sequentialized_barnard_tests
3733
```
3834

3935
## Installation Instructions \[Dev\]
@@ -52,9 +48,87 @@ $ pre-commit install
5248
We assume that any specified virtual / conda environment has been activated for all subsequent code snippets.
5349

5450
# Quick Start Guides
55-
We include key notes for understanding the core ideas of the STEP code. Quick-start resources are included in both shell script and notebook form.
51+
## Convenience: Automatic Test Instantiation
52+
53+
For convenience, you can automatically select between STEP and Lai (a baseline method) depending on the value of `n_max` using the factory function in `auto.py`:
54+
55+
```python
56+
from sequentialized_barnard_tests import get_mirrored_test
57+
test = get_mirrored_test(n_max, alternative, alpha, verbose=True, ...)
58+
```
59+
If `n_max > 500`, this will instantiate a `MirroredLaiTest`, which is a computationally efficient baseline with comparable performance to `MirroredStepTest` for a large-enough sample size; otherwise, it will use the more powerful `MirroredStepTest` which can take longer to synthesize the decision rule. All shared and class-specific arguments can be passed as keyword arguments.
60+
61+
## Example Usage
62+
63+
Below is a minimal example, run on three different sets of policy evaluation data that lead to three distinct evaluation results.
64+
65+
### Case 1: Test yields `AcceptAlternative`
66+
```python
67+
from sequentialized_barnard_tests import get_mirrored_test, Hypothesis
68+
69+
n_max = 100 # maximum sample size is 100 (per policy)
70+
alternative = Hypothesis.P0LessThanP1 # we want to test if "success rate of the first policy < success rate of the second policy"
71+
alpha = 0.05 # false positive rate is 5%
72+
73+
test = get_mirrored_test(n_max=n_max, alternative=alternative, alpha=alpha)
74+
75+
success_array_policy_0 = [False] * 10 # the first policy failed 10 times
76+
success_array_policy_1 = [True] * 10 # the second policy succeeded 10 times
77+
78+
result = test.run_on_sequence(success_array_policy_0, success_array_policy_1)
79+
decision = result.decision
80+
print(decision) # AcceptAlternative: success rate of the first policy < success rate of the second policy with 95% confidence
81+
```
5682

57-
## Quick Start Guide: Making a STEP Policy for Specific \{n_max, alpha\}
83+
### Case 2: Test yields `AcceptNull`
84+
```python
85+
from sequentialized_barnard_tests import get_mirrored_test, Hypothesis
86+
87+
n_max = 100 # maximum sample size is 100 (per policy)
88+
alternative = Hypothesis.P0LessThanP1 # we want to test if "success rate of the first policy < success rate of the second policy"
89+
alpha = 0.05 # false positive rate is 5%
90+
91+
test = get_mirrored_test(n_max=n_max, alternative=alternative, alpha=alpha)
92+
93+
success_array_policy_0 = [True] * 10 # the first policy succeeded 10 times
94+
success_array_policy_1 = [False] * 10 # the second policy failed 10 times
95+
96+
result = test.run_on_sequence(success_array_policy_0, success_array_policy_1)
97+
decision = result.decision
98+
print(decision) # AcceptNull: success rate of the first policy > success rate of the second policy with 95% confidence
99+
```
100+
101+
Note: `AcceptNull` is a valid decision only for "mirrored" tests.
102+
In our terminology, a mirrored test is one that runs two one-sided tests
103+
simultaneously, with the null and the alternative flipped from each other.
104+
(Because of the monotonicity of the test statistic, mirrored tests suffer no penalty for
105+
running two tests simultaneously, and therefore essentially dominate one-sided tests.)
106+
In the example above, the alternative is `Hypothesis.P0LessThanP1` and the decision is
107+
`Decision.AcceptNull`, which should be interpreted as accepting `Hypothesis.P0MoreThanP1`.
108+
If you would rather use a more conventional one-sided test, you can instantiate one by calling
109+
`get_test` instead of `get_mirrored_test`.
110+
111+
### Case 3: Test yields `FailToDecide`
112+
```python
113+
from sequentialized_barnard_tests import get_mirrored_test, Hypothesis
114+
115+
n_max = 100 # maximum sample size is 100 (per policy)
116+
alternative = Hypothesis.P0LessThanP1 # we want to test if "success rate of the first policy < success rate of the second policy"
117+
alpha = 0.05 # false positive rate is 5%
118+
119+
test = get_mirrored_test(n_max=n_max, alternative=alternative, alpha=alpha)
120+
121+
success_array_policy_0 = [True, False, False, True] # the first policy succeeded 2 out of 4 times
122+
success_array_policy_1 = [False, True, True, True] # the second policy succeeded 3 out of 4 times
123+
124+
result = test.run_on_sequence(success_array_policy_0, success_array_policy_1)
125+
decision = result.decision
126+
print(decision) # FailToDecide: difference was not statistically separable; user can collect 100 - 4 = 96 more rollouts for each policy to re-run the test.
127+
```
128+
129+
## Key Notes for Understanding the Core Ideas of STEP Code
130+
131+
We include key notes for understanding the core ideas of the STEP code. Quick-start resources are included in both shell script and notebook form.
58132

59133
### (1A) Understanding the Accepted Shape Parameters
60134
In order to synthesize a STEP Policy for specific values of n_max and alpha, one additional set of parametric decisions will be required. The user will need to set the risk budget shape, which is specified by choice of function family (p-norm vs zeta-function) and particular shape parameter. The shape parameter is real-valued; it is used directly for zeta functions and is exponentiated for p-norms.
@@ -87,20 +161,22 @@ Generalizing the accepted risk budgets to arbitrary monotonic sequences $`\{0, \
87161
Having decided an appropriate form for the risk budget shape, policy synthesis is straightforward to run. From the base directory, the general command would be:
88162

89163
```bash
90-
$ python scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {shape_parameter} -up {use_p_norm}
164+
$ python sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {shape_parameter} -up {use_p_norm}
91165
```
92166

167+
Note: This script will be called automatically upon instantiation of a test object, if the corresponding policy file is missing from `sequentialized_barnard_tests/policies/`.
168+
93169
### (2B) What If I Don't Know the Right Risk Budget?
94170
We recommend using the default linear risk budget, which is the shape *used in the paper*. This corresponds to \{shape_parameter\}$`= 0.0`$ for each shape family. Thus, *each of the following commands constructs the same policy*:
95171

96172
```bash
97-
$ python scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha}
173+
$ python sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha}
98174
```
99175
```bash
100-
$ python scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {0.0} -up "True"
176+
$ python sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {0.0} -up "True"
101177
```
102178
```bash
103-
$ python scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {0.0} -up "False"
179+
$ python sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py -n {n_max} -a {alpha} -pz {0.0} -up "False"
104180
```
105181

106182
Note: For \{shape_parameter\} $`\neq 0`$, the shape families differ. Therefore, the choice of \{use_p_norm\} *will affect the STEP policy*.
@@ -113,7 +189,7 @@ $ sequentialized_barnard_tests/policies/
113189

114190
- At present, we have not tested extensively beyond \{n_max\}$`=500`$. Going beyond this limit may lead to issues, which become more likely as \{n_max\} grows. The code will also require increasing amounts of RAM as \{n_max\} is increased.
115191

116-
## Quick Start Guide: Evaluation on Real Data
192+
## Script-Based Evaluation on Real Data
117193

118194
We now assume that a STEP policy has been constructed for the target problem. This can either be one of the default policies, or a newly constructed one following the recipe in the preceding section.
119195
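
The README note on mirrored tests (how `AcceptNull` should be read when the alternative is `Hypothesis.P0LessThanP1`) can be illustrated with a small sketch. The enums below are local stand-ins that reuse the names appearing in the README; they are not imported from the package, and `interpret` is a hypothetical helper, not something the library is known to provide.

```python
from enum import Enum
from typing import Optional


class Hypothesis(Enum):
    # Stand-in for the package's Hypothesis enum (names taken from the README).
    P0LessThanP1 = "p0 < p1"
    P0MoreThanP1 = "p0 > p1"


class Decision(Enum):
    # Stand-in for the package's Decision enum (names taken from the README).
    AcceptAlternative = 1
    AcceptNull = 2
    FailToDecide = 3


def interpret(alternative: Hypothesis, decision: Decision) -> Optional[Hypothesis]:
    """Return the hypothesis a mirrored test supports, or None if undecided."""
    flipped = {
        Hypothesis.P0LessThanP1: Hypothesis.P0MoreThanP1,
        Hypothesis.P0MoreThanP1: Hypothesis.P0LessThanP1,
    }
    if decision is Decision.AcceptAlternative:
        return alternative
    if decision is Decision.AcceptNull:
        # In a mirrored test, accepting the null means accepting the flipped alternative.
        return flipped[alternative]
    return None  # FailToDecide: neither direction is supported yet
```

Under these assumptions, `interpret(Hypothesis.P0LessThanP1, Decision.AcceptNull)` returns `Hypothesis.P0MoreThanP1`, which matches the reading given in Case 2.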

sequentialized_barnard_tests/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
from .lai import LaiTest, MirroredLaiTest
33
from .savi import MirroredSaviTest, SaviTest
44
from .step import MirroredStepTest, StepTest
5+
from .auto import get_test, get_mirrored_test

sequentialized_barnard_tests/auto.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""
2+
Factory function for automatic selection of STEP or Lai based on n_max.
3+
"""
4+
5+
from sequentialized_barnard_tests.step import MirroredStepTest, StepTest
6+
from sequentialized_barnard_tests.lai import MirroredLaiTest, LaiTest
7+
8+
9+
def get_test(n_max: int, alternative, alpha: float, verbose: bool = False, **kwargs):
10+
"""
11+
Factory function to select StepTest or LaiTest based on n_max.
12+
Uses LaiTest for n_max > 500, otherwise StepTest.
13+
14+
Shared arguments:
15+
n_max (int): Maximal sequence length.
16+
alternative: Specification of the alternative hypothesis.
17+
alpha (float): Significance level of the test.
18+
verbose (bool, optional): If True, print outputs to stdout.
19+
Additional arguments for each class can be passed as keyword arguments.
20+
"""
21+
if n_max > 500:
22+
if verbose:
23+
print("Using LaiTest for n_max > 500")
24+
return LaiTest(alternative, n_max, alpha, verbose=verbose, **kwargs)
25+
else:
26+
if verbose:
27+
print("Using StepTest for n_max <= 500")
28+
return StepTest(alternative, n_max, alpha, verbose=verbose, **kwargs)
29+
30+
31+
def get_mirrored_test(
32+
n_max: int, alternative, alpha: float, verbose: bool = False, **kwargs
33+
):
34+
"""
35+
Factory function to select MirroredStepTest or MirroredLaiTest based on n_max.
36+
Uses MirroredLaiTest for n_max > 500, otherwise MirroredStepTest.
37+
38+
Shared arguments:
39+
n_max (int): Maximal sequence length.
40+
alternative: Specification of the alternative hypothesis.
41+
alpha (float): Significance level of the test.
42+
verbose (bool, optional): If True, print outputs to stdout.
43+
Additional arguments for each class can be passed as keyword arguments.
44+
"""
45+
if n_max > 500:
46+
if verbose:
47+
print("Using MirroredLaiTest for n_max > 500")
48+
return MirroredLaiTest(alternative, n_max, alpha, verbose=verbose, **kwargs)
49+
else:
50+
if verbose:
51+
print("Using MirroredStepTest for n_max <= 500")
52+
return MirroredStepTest(alternative, n_max, alpha, verbose=verbose, **kwargs)
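
The selection rule in the two factories above reduces to a single threshold check. The sketch below isolates that rule without importing the package; `LAI_THRESHOLD` and `select_test_name` are hypothetical names introduced only for illustration.

```python
LAI_THRESHOLD = 500  # cutoff used by get_test / get_mirrored_test in this commit


def select_test_name(n_max: int, mirrored: bool = True) -> str:
    """Return the name of the class the factory would pick for a given n_max."""
    family = "Lai" if n_max > LAI_THRESHOLD else "Step"
    prefix = "Mirrored" if mirrored else ""
    return f"{prefix}{family}Test"


for n_max in (100, 500, 501):
    print(n_max, select_test_name(n_max))
# 100 MirroredStepTest
# 500 MirroredStepTest
# 501 MirroredLaiTest
```

Note that the factories accept `n_max` first but forward `alternative` first to the class constructors, so calling them with keyword arguments (as the README examples do) avoids any confusion about argument order.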

scripts/synthesize_general_step_policy.py renamed to sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
55
Example Default Usage (all equivalent, using default params):
66
7-
(1) python scripts/synthesize_step_policy.py
8-
(2) python scripts/synthesize_step_policy.py -n 200 -a 0.05
9-
(3) python scripts/synthesize_step_policy.py --n_max 200 --alpha 0.05 --n_points 129
7+
(1) python sequentialized_barnard_tests/scripts/synthesize_step_policy.py
8+
(2) python sequentialized_barnard_tests/scripts/synthesize_step_policy.py -n 200 -a 0.05
9+
(3) python sequentialized_barnard_tests/scripts/synthesize_step_policy.py --n_max 200 --alpha 0.05 --n_points 129
1010
1111
Example Non-Default Parameter Usage:
1212
13-
python scripts/synthesize_step_policy.py -n 400
13+
python sequentialized_barnard_tests/scripts/synthesize_step_policy.py -n 400
1414
"""
1515

1616
import argparse
@@ -481,7 +481,7 @@ def run_step_policy_synthesis(
481481
"-up",
482482
"--use_p_norm",
483483
type=str,
484-
default=False,
484+
default="False",
485485
help=(
486486
"Toggle whether to use p_norm or zeta function shape family for the risk budget. "
487487
"If True, uses p_norm shape; else, uses zeta function shape family. "
@@ -530,7 +530,7 @@ def run_step_policy_synthesis(
530530
major_axis_length = args.major_axis_length
531531

532532
base_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
533-
results_path = f"sequentialized_barnard_tests/policies/n_max_{args.n_max}_alpha_{args.alpha}_shape_parameter_{args.log_p_norm}_pnorm_{args.use_p_norm}/"
533+
results_path = f"policies/n_max_{args.n_max}_alpha_{args.alpha}_shape_parameter_{args.log_p_norm}_pnorm_{args.use_p_norm}/"
534534
full_save_path = os.path.normpath(os.path.join(base_path, results_path))
535535
if not os.path.isdir(full_save_path):
536536
os.makedirs(full_save_path)
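
The change to `results_path` follows from the script's new location: `base_path` now resolves to the package directory, so the `sequentialized_barnard_tests/` prefix is no longer needed. A minimal sketch of the path arithmetic is shown below. It uses `posixpath` so the result is the same on every platform, skips the `realpath` call, and assumes a hypothetical checkout at `/repo`; `resolve_policy_dir` is an illustrative helper, not part of the script.

```python
import posixpath


def resolve_policy_dir(script_file, n_max, alpha, shape_parameter, use_p_norm):
    """Mimic how the synthesis script builds its output directory."""
    # One level up from scripts/ is the package directory.
    base_path = posixpath.join(posixpath.dirname(script_file), "..")
    results_path = (
        f"policies/n_max_{n_max}_alpha_{alpha}"
        f"_shape_parameter_{shape_parameter}_pnorm_{use_p_norm}/"
    )
    return posixpath.normpath(posixpath.join(base_path, results_path))


example = resolve_policy_dir(
    "/repo/sequentialized_barnard_tests/scripts/synthesize_general_step_policy.py",
    200,
    0.05,
    0.0,
    "False",
)
print(example)
# /repo/sequentialized_barnard_tests/policies/n_max_200_alpha_0.05_shape_parameter_0.0_pnorm_False
```

The resulting directory name has the same form as the `policies/...` entries listed under `package_data` in `setup.py`, which is where `load_existing_policy` looks for `policy_compressed.pkl`.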

sequentialized_barnard_tests/step.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,9 @@
1010

1111
import os
1212
import pickle
13+
import subprocess
14+
import sys
1315
import warnings
14-
from pathlib import Path
1516
from typing import Optional, Union
1617

1718
import numpy as np
@@ -323,15 +324,10 @@ def load_existing_policy(
323324
verbose (bool, optional): If True, print the outputs to stdout.
324325
Defaults to False.
325326
"""
326-
# print(str(Path(os.path.dirname(os.path.abspath(__file__))).resolve()))
327-
# print(os.path.dirname(os.path.abspath(__file__)))
328-
# print(os.path.dirname(__file__))
329-
# self.policy_path = os.path.join(
330-
# str(
331-
# Path(os.path.dirname(os.path.abspath(__file__))).resolve()
332-
# ), # os.path.dirname(os.path.abspath(__file__)),
333-
# f"policies/n_max_{self.n_max}_alpha_{self.alpha}_shape_parameter_{self.shape_parameter}_pnorm_{self.use_p_norm}/policy_compressed.pkl",
334-
# )
327+
# Determine the path to the synthesis script
328+
script_dir = os.path.join(os.path.dirname(__file__), "scripts")
329+
synth_script = os.path.join(script_dir, "synthesize_general_step_policy.py")
330+
335331
policy_path = os.path.join(
336332
os.path.dirname(__file__),
337333
f"policies/n_max_{self.n_max}_alpha_{self.alpha}_shape_parameter_{self.shape_parameter}_pnorm_{self.use_p_norm}/",
@@ -342,11 +338,41 @@ def load_existing_policy(
342338
with open(policy_path, "rb") as filename:
343339
self.policy = pickle.load(filename)
344340
self.need_new_policy = False
345-
except:
346-
warnings.warn(
347-
f"Current policy path: {policy_path}"
348-
"Unable to find policy with the assigned test parameters. An additional policy synthesis procedure may be required."
349-
)
341+
except Exception as e:
342+
print(f"Policy not found at {policy_path}. Attempting synthesis...")
343+
# Build command to synthesize policy
344+
cmd = [
345+
sys.executable,
346+
synth_script,
347+
"--n_max",
348+
str(self.n_max),
349+
"--alpha",
350+
str(self.alpha),
351+
"--n_points",
352+
"129", # default value, could be parameterized
353+
"--lambda_value",
354+
"2.1", # default, could be parameterized
355+
"--major_axis_length",
356+
"1.4", # default, could be parameterized
357+
"--log_p_norm",
358+
str(self.shape_parameter),
359+
"--use_p_norm",
360+
str(self.use_p_norm),
361+
]
362+
print(f"Running synthesis command: {' '.join(cmd)}")
363+
result = subprocess.run(cmd, cwd=script_dir, capture_output=False)
364+
# Try loading again
365+
try:
366+
with open(policy_path, "rb") as filename:
367+
self.policy = pickle.load(filename)
368+
self.need_new_policy = False
369+
print("Policy synthesis and loading successful.")
370+
except Exception as e2:
371+
warnings.warn(
372+
f"Unable to synthesize or load policy at {policy_path}. Error: {e2}"
373+
)
374+
self.policy = None
375+
self.need_new_policy = True
350376

351377
self.policy_path = policy_path
352378

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22

33
setup(
44
name="sequentialized_barnard_tests",
5-
version="0.0.1",
5+
version="0.0.4",
66
description="Sequential statistical hypothesis testing for two-by-two contingency tables.",
77
authors=["David Snyder", "Haruki Nishimura"],
88
author_emails=["[email protected]", "[email protected]"],
99
license="MIT",
1010
packages=find_packages(),
1111
package_data={
1212
"sequentialized_barnard_tests": [
13+
"scripts/synthesize_general_step_policy.py",
1314
"data/lai_calibration_data.npy",
1415
"policies/n_max_100_alpha_0.05_shape_parameter_0.0_pnorm_False/policy_compressed.pkl",
1516
"policies/n_max_200_alpha_0.05_shape_parameter_0.0_pnorm_False/policy_compressed.pkl",
