Skip to content

Commit 4a517ef

Browse files
Merge pull request #21 from TRI-ML/step_evaluation_dev
Add step.py as well as unit tests and sample policies. TODO: automatic call to policy synthesis.
2 parents b458c9b + 8f8d638 commit 4a517ef

File tree

18 files changed

+1068
-6
lines changed

18 files changed

+1068
-6
lines changed
Binary file not shown.
Binary file not shown.
Binary file not shown.

scripts/synthesize_step_policy.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def run_step_policy_synthesis(
220220
pass
221221
else:
222222
if idx0 + idx1 == t:
223-
FEATURES_BASE[:, feature_counter] = copy.deepcopy(
223+
FEATURES_BASE[:, feature_counter] = 2.0 * copy.deepcopy(
224224
STATE_DIST_POST[idx0, idx1, :]
225225
)
226226
DISPOSABLE_CANDIDATE_STATE_ENCODING[idx0, idx1] = 0.0
@@ -394,6 +394,25 @@ def run_step_policy_synthesis(
394394

395395
args = parser.parse_args()
396396

397+
if args.n_max == 100:
398+
lambda_value = 2.1
399+
major_axis_length = 1.4
400+
elif args.n_max == 200:
401+
lambda_value = 2.1
402+
major_axis_length = 1.15
403+
elif args.n_max == 300:
404+
lambda_value = 2.1
405+
major_axis_length = 1.4
406+
elif args.n_max == 400:
407+
lambda_value = 2.1
408+
major_axis_length = 1.4
409+
elif args.n_max == 500:
410+
lambda_value = 2.2
411+
major_axis_length = 1.35
412+
else:
413+
lambda_value = args.lambda_value
414+
major_axis_length = args.major_axis_length
415+
397416
(
398417
POLICY_LIST_COMPRESSED,
399418
RISK_ACCUMULATION,
@@ -402,8 +421,8 @@ def run_step_policy_synthesis(
402421
args.n_max,
403422
args.alpha,
404423
args.n_points,
405-
args.lambda_value,
406-
args.major_axis_length,
424+
lambda_value,
425+
major_axis_length,
407426
args.log_p_norm,
408427
args.use_p_norm,
409428
)

scripts/visualize_step_policy.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Method to construct visualizations of the STEP near-optimal decision making policy.
2+
3+
This is primarily a debugging tool, useful for the designer to visually verify that the
4+
policy is incorporating states in a logical / explainable manner.
5+
"""
6+
7+
import argparse
8+
import copy
9+
import os
10+
11+
import numpy as np
12+
from matplotlib import pyplot as plt
13+
from tqdm import tqdm
14+
15+
from sequentialized_barnard_tests import StepTest
16+
from sequentialized_barnard_tests.base import Decision, Hypothesis
17+
18+
19+
def visualize_step_policy(
    n_max: int,
    alpha: float,
    risk_budget_shape_parameter: float = 0.0,
    use_p_norm: bool = False,
    mirrored: bool = True,
    alternative: Hypothesis = Hypothesis.P0LessThanP1,
):
    """Render per-step heatmap images of a STEP near-optimal decision policy.

    This is primarily a debugging tool, useful for the designer to visually
    verify that the policy is incorporating states in a logical / explainable
    manner. One PNG per time step t in [0, n_max] is written under
    ``media/im/<policy-id>/{t:03d}.png``.

    Args:
        n_max: Maximum number of robot policy evals (per policy).
        alpha: Maximal allowed Type-1 error rate of the test.
        risk_budget_shape_parameter: Shape parameter of the risk budget
            (log(p) if ``use_p_norm`` else the zeta-function argument).
        use_p_norm: Whether the risk budget uses the p-norm shape family.
        mirrored: If True, render the mirrored (two-sided) policy: stop
            decisions are drawn above the diagonal and their negations below.
        alternative: Alternative hypothesis the policy was synthesized for;
            only consulted when ``mirrored`` is False.

    Returns:
        int: 1 on success (legacy exit-status convention of this script).

    Raises:
        ValueError: If no policy matching the parameters can be loaded, or if
            the loaded policy does not have length ``n_max``.
    """
    step_test = StepTest(
        alternative, n_max, alpha, risk_budget_shape_parameter, use_p_norm
    )
    step_test.load_existing_policy()

    if step_test.policy is None:
        raise ValueError(
            "Unable to find a policy with these parameters. Please double check or run appropriate policy synthesis. "
        )

    # Set up and create the directory in which to save the appropriate images.
    # exist_ok avoids the isdir/makedirs race of the check-then-create pattern.
    policy_id_str = f"n_max_{n_max}_alpha_{alpha}_shape_parameter_{risk_budget_shape_parameter}_pnorm_{use_p_norm}/"
    media_save_path = "media/im/" + policy_id_str
    os.makedirs(media_save_path, exist_ok=True)

    # Work on a copy so the loaded test object is left untouched.
    policy_to_visualize = copy.deepcopy(step_test.policy)

    # Explicit validation (not `assert`) so the check survives `python -O`.
    if len(policy_to_visualize) != n_max:
        print(
            f"Issue with policy consistency; should be length {n_max}, but is length {len(policy_to_visualize)}"
        )
        raise ValueError(
            "policy appears to be of incorrect length. Please verify the synthesis procedure."
        )

    fig2, ax2 = plt.subplots(figsize=(10, 10))

    # Iterate through each time step, building a (t+1)x(t+1) policy array and
    # saving it as an image. Rebinding the arrays each iteration replaces the
    # original del-inside-try pattern.
    for t in tqdm(range(n_max + 1)):
        policy_array = np.zeros((t + 1, t + 1))
        # policy_to_visualize[t - 1] is a list of per-x decision arrays; at
        # t == 0 there is no decision yet, and the inner loops never index it.
        decision_arrays_t = policy_to_visualize[t - 1] if t >= 1 else [0]

        for x_absolute in range(t + 1):
            # Only strictly off-diagonal states (y > x) carry decisions.
            for y_absolute in range(x_absolute + 1, t + 1):
                decision_array = decision_arrays_t[x_absolute]
                # Number of non-zero / non-unity policy bins at this x and t.
                num_random_bins = decision_array.shape[0] - 1
                # Highest value of y for which we CONTINUE [i.e., policy = 0].
                critical_zero_y = int(decision_array[0])

                if y_absolute <= critical_zero_y:
                    prob_stop = 0.0  # continue region
                elif y_absolute > (critical_zero_y + num_random_bins):
                    prob_stop = 1.0  # certain-stop region
                else:
                    # Randomized-stop bin.
                    prob_stop = decision_array[y_absolute - critical_zero_y]

                if mirrored:
                    # Assign the decision to [x, y] and its negation to [y, x].
                    policy_array[x_absolute, y_absolute] = prob_stop
                    policy_array[y_absolute, x_absolute] = -prob_stop
                elif alternative == Hypothesis.P0LessThanP1:
                    policy_array[x_absolute, y_absolute] = prob_stop
                else:
                    policy_array[y_absolute, x_absolute] = -prob_stop

        # Save off policy array as an image, reusing the same axes each step.
        ax2.cla()
        ax2.pcolormesh(
            np.arange(t + 2) / (t + 1),
            np.arange(t + 2) / (t + 1),
            np.transpose(policy_array),
            cmap="RdYlBu",
            vmin=-1.2,
            vmax=1.2,
        )
        ax2.plot([0, 1], [0, 1], "k--", linewidth=5)
        ax2.set_xlabel("Baseline Performance", fontsize=24)
        ax2.set_ylabel("Test Policy Performance", fontsize=24)
        ax2.tick_params(labelsize=20)
        ax2.text(0.05, 0.95, f"n = {t}", color="#FFFFFF", fontsize=24, weight="heavy")
        ax2.set_aspect("equal")
        ax2.grid(True)
        fig2.savefig(media_save_path + f"{t:03d}.png", dpi=450)

    # Release the figure; the original leaked the handle on every call.
    plt.close(fig2)

    return 1
141+
142+
143+
if __name__ == "__main__":

    def _cli_bool(value: str) -> bool:
        """Parse a boolean command-line value.

        argparse's ``type=bool`` is a well-known bug: ``bool("False")`` is
        True because any non-empty string is truthy. This parser accepts the
        usual spellings of true/false and rejects everything else.
        """
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(
        description=(
            "This script visualizes a near-optimal STEP policy for a given "
            "{n_max, alpha} combination, saving one image per time step to "
            "'media/im/'. The policy itself must already have been "
            "synthesized; this tool is primarily for visually verifying that "
            "the policy incorporates states in a logical / explainable manner."
        )
    )
    parser.add_argument(
        "-n",
        "--n_max",
        type=int,
        default=200,
        help=(
            "Maximum number of robot policy evals (per policy) in the evaluation procedure. "
            "Defaults to 200."
        ),
    )
    parser.add_argument(
        "-a",
        "--alpha",
        type=float,
        default=0.05,
        help=(
            "Maximal allowed Type-1 error rate of the statistical testing procedure. "
            "Defaults to 0.05."
        ),
    )
    parser.add_argument(
        "-pz",
        "--log_p_norm",
        type=float,
        default=0.0,
        help=(
            "Rate at which risk is accumulated, reflecting user's belief about underlying "
            "likelihood of different alternatives and nulls being true. If using a p_norm "
            ", this variable is equivalent to log(p). If not using a p_norm, this is the "
            "argument to the zeta function, partial sums of which give the shape of the risk budget. "
            "Defaults to 0.0."
        ),
    )
    parser.add_argument(
        "-up",
        "--use_p_norm",
        type=_cli_bool,
        default=False,
        help=(
            "Toggle whether to use p_norm or zeta function shape family for the risk budget. "
            "If True, uses p_norm shape; else, uses zeta function shape family. "
            "Defaults to False (zeta function partial sum family)."
        ),
    )
    # TODO: add mirrored, alternative

    args = parser.parse_args()

    # Positional mapping: args.log_p_norm supplies risk_budget_shape_parameter.
    exit_status = visualize_step_policy(
        args.n_max, args.alpha, args.log_p_norm, args.use_p_norm
    )
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from .base import Decision, Hypothesis, TestResult
2-
from .lai import LaiTest
2+
from .lai import LaiTest, MirroredLaiTest
3+
from .savi import MirroredSaviTest, SaviTest
4+
from .step import MirroredStepTest, StepTest

0 commit comments

Comments
 (0)