Skip to content

Commit 43bfe01

Browse files
Reflect all changes in PR 25 and verify on new policy synthesis
1 parent 9b4c959 commit 43bfe01

File tree

2 files changed

+39
-22
lines changed

2 files changed

+39
-22
lines changed

scripts/synthesize_general_step_policy.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import copy
1818
import os
1919
import pickle
20+
from typing import Optional
2021

2122
import numpy as np
2223
from numpy.typing import ArrayLike
@@ -43,10 +44,10 @@ def run_step_policy_synthesis(
4344
major_axis_length: float,
4445
risk_budget_shape_parameter: float = 0.0,
4546
use_p_norm: bool = False,
46-
custom_differential_risk_budget: ArrayLike = None,
47-
dead_time: int = None,
47+
custom_differential_risk_budget: Optional[ArrayLike] = None,
48+
dead_time: Optional[int] = None,
4849
save_policy_array: bool = False,
49-
save_policy_path: str = None,
50+
save_policy_path: Optional[str] = None,
5051
verbose: bool = False,
5152
):
5253
"""Procedure to synthesize a near-optimal finite-sample test for the policy comparison problem (assuming p1 > p0 is the alternative). This is the foundation for the
@@ -68,9 +69,10 @@ def run_step_policy_synthesis(
6869
6970
Raises:
7071
ValueError: If invalid required arguments
72+
ValueError: If inconsistent specification of whether to save the uncompressed policy
7173
ValueError: If invalid specified dead_time
72-
ValueError: If cumulative mass removal arrays do not terminate near alpha (making the procedure either loose, if below alpha, or invalid, if above alpha)
73-
ValueError: If control points are not assigned with proper extremal (min and max) limits
74+
RuntimeError: If cumulative mass removal arrays do not terminate near alpha (making the procedure either loose, if below alpha, or invalid, if above alpha)
75+
RuntimeError: If control points are not assigned with proper extremal (min and max) limits
7476
7577
Returns:
7678
POLICY_LIST_COMPRESSED (ArrayLike): The compressed representation of the accept/reject comparison policy.
@@ -125,15 +127,14 @@ def run_step_policy_synthesis(
125127
assert np.isclose(alpha, cumulative_mass_removal_array[-1])
126128
assert np.isclose(alpha, np.sum(diff_mass_removal_array))
127129
except:
128-
raise ValueError(
130+
raise RuntimeError(
129131
"Inconsistent cumulative and differential mass removal arrays; will lead to unpredictable optimization behavior!"
130132
)
131133

132134
##########
133135
# HANDLE Kernels, storage matrices, and encoding matrices
134136
# HANDLE capacity to compress the policy as we go
135137
##########
136-
# # TODO: more principled setup than the empirical shape parameters for quadratic_score
137138
# # Compute extremal 0 < p_min, p_max < 1 that contain risk of positive delta
138139
# p_max = np.exp(np.log(1.0 - alpha - 1e-5) / n_max)
139140
# p_min = 1.0 - p_max
@@ -152,7 +153,7 @@ def run_step_policy_synthesis(
152153
assert np.isclose(POINTS_ARRAY[-1], p_max)
153154
assert np.isclose(POINTS_ARRAY[0], p_min)
154155
except:
155-
raise ValueError(
156+
raise RuntimeError(
156157
"Error in assigning control points of worst-case null hypotheses; extremal values do not match [p_min, p_max]"
157158
)
158159

@@ -186,7 +187,7 @@ def run_step_policy_synthesis(
186187
)
187188
POLICY_LIST_COMPRESSED.append(policy_array_compressed)
188189
# Begin loop to synthesize the optimal policy
189-
for t in tqdm(range(1, n_max + 1)):
190+
for t in tqdm(range(1, n_max + 1), desc="STEP Policy Synthesis"):
190191
# Don't propagate zeros -- waste of time and effort
191192
critical_limit = int(np.minimum(n_max + 1, t + 1))
192193

@@ -521,9 +522,9 @@ def run_step_policy_synthesis(
521522
lambda_value = args.lambda_value
522523
major_axis_length = args.major_axis_length
523524

524-
base_path = os.getcwd()
525+
base_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
525526
results_path = f"sequentialized_barnard_tests/policies/n_max_{args.n_max}_alpha_{args.alpha}_shape_parameter_{args.log_p_norm}_pnorm_{args.use_p_norm}/"
526-
full_save_path = os.path.join(base_path, results_path)
527+
full_save_path = os.path.normpath(os.path.join(base_path, results_path))
527528
if not os.path.isdir(full_save_path):
528529
os.makedirs(full_save_path)
529530

@@ -548,8 +549,10 @@ def run_step_policy_synthesis(
548549
save_policy_path=special_policy_array_save_path,
549550
)
550551

551-
with open(full_save_path + "policy_compressed.pkl", "wb") as filename:
552+
with open(full_save_path + "/" + "policy_compressed.pkl", "wb") as filename:
552553
pickle.dump(POLICY_LIST_COMPRESSED, filename)
553554

554-
np.save(full_save_path + f"risk_accumulation.npy", RISK_ACCUMULATION)
555-
np.save(full_save_path + f"points_array.npy", POINTS_ARRAY)
555+
np.save(full_save_path + "/" + f"risk_accumulation.npy", RISK_ACCUMULATION)
556+
np.save(full_save_path + "/" + f"points_array.npy", POINTS_ARRAY)
557+
558+
print(f"STEP policy saved at {full_save_path}.")

scripts/visualize_step_policy.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,15 @@ def visualize_step_policy(
5757
# Set up and create the directory in which to save the appropriate images.
5858
policy_id_str = f"n_max_{n_max}_alpha_{alpha}_shape_parameter_{risk_budget_shape_parameter}_pnorm_{use_p_norm}/"
5959

60-
check_array_base_str = (
61-
f"sequentialized_barnard_tests/policies/" + policy_id_str + f"array/time_"
60+
base_dir = os.path.normpath(
61+
os.path.join(
62+
os.path.dirname(os.path.realpath(__file__)),
63+
"..",
64+
)
65+
)
66+
check_array_base_str = os.path.join(
67+
base_dir,
68+
f"sequentialized_barnard_tests/policies/" + policy_id_str + f"array/time_",
6269
)
6370
try:
6471
np.load(check_array_base_str + f"{5}.npy")
@@ -69,7 +76,7 @@ def visualize_step_policy(
6976
if compute_reconstruction_error_flag:
7077
error_by_timestep = np.zeros(n_max + 1)
7178

72-
media_save_path = "media/im/policies/" + policy_id_str
79+
media_save_path = os.path.join(base_dir, "media/im/policies/", policy_id_str)
7380

7481
if not os.path.isdir(media_save_path):
7582
os.makedirs(media_save_path)
@@ -90,7 +97,7 @@ def visualize_step_policy(
9097
fig2, ax2 = plt.subplots(figsize=(10, 10))
9198

9299
# Iterate through loop to generate policy_array and associated images
93-
for t in tqdm(range(n_max + 1)):
100+
for t in tqdm(range(n_max + 1), desc="STEP Policy Visualization"):
94101
try:
95102
del policy_array
96103
del decision_array_t
@@ -237,11 +244,18 @@ def visualize_step_policy(
237244
# TODO: add mirrored, alternative
238245

239246
args = parser.parse_args()
240-
247+
base_dir = os.path.normpath(
248+
os.path.join(
249+
os.path.dirname(os.path.realpath(__file__)),
250+
"..",
251+
)
252+
)
241253
policy_id_str = f"n_max_{args.n_max}_alpha_{args.alpha}_shape_parameter_{args.log_p_norm}_pnorm_{args.use_p_norm}/"
242-
full_load_str = f"sequentialized_barnard_tests/policies/" + policy_id_str
243-
media_save_path = "media/im/policies/" + policy_id_str
244-
scripts_save_path = "scripts/im/policies/" + policy_id_str
254+
full_load_str = os.path.join(
255+
base_dir, f"sequentialized_barnard_tests/policies/", policy_id_str
256+
)
257+
media_save_path = os.path.join(base_dir, "media/im/policies/", policy_id_str)
258+
scripts_save_path = os.path.join(base_dir, "scripts/im/policies/", policy_id_str)
245259

246260
if not os.path.isdir(media_save_path):
247261
os.makedirs(media_save_path)

0 commit comments

Comments
 (0)