Skip to content

Commit 841fa7c

Browse files
committed
Further debug interfac.e
1 parent 859a11f commit 841fa7c

File tree

3 files changed

+174
-80
lines changed

3 files changed

+174
-80
lines changed

src/dcegm/final_periods.py

Lines changed: 13 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,16 @@ def solve_last_two_periods(
4848
for all states, end of period assets, and income shocks.
4949
5050
"""
51-
52-
idx_state_choices_final_period = last_two_period_batch_info[
53-
"idx_state_choices_final_period"
54-
]
55-
if debug_info is not None:
56-
if "rescale_idx" in debug_info.keys():
57-
# If we want to rescale the idx, because we only solve part of the model, then to this first.
58-
idx_state_choices_final_period = (
59-
idx_state_choices_final_period - debug_info["rescale_idx"]
60-
)
61-
6251
(
6352
value_solved,
6453
policy_solved,
6554
endog_grid_solved,
6655
value_interp_final_period,
6756
marginal_utility_final_last_period,
6857
) = solve_final_period(
69-
idx_state_choices_final_period=idx_state_choices_final_period,
58+
idx_state_choices_final_period=last_two_period_batch_info[
59+
"idx_state_choices_final_period"
60+
],
7061
idx_parent_states_final_period=last_two_period_batch_info[
7162
"idxs_parent_states_final_period"
7263
],
@@ -118,38 +109,25 @@ def solve_last_two_periods(
118109

119110
idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"]
120111

112+
value_solved = value_solved.at[idx_second_last, ...].set(
113+
out_dict_second_last["value"]
114+
)
115+
policy_solved = policy_solved.at[idx_second_last, ...].set(
116+
out_dict_second_last["policy"]
117+
)
118+
endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(
119+
out_dict_second_last["endog_grid"]
120+
)
121+
121122
# If we do not call the function in debug mode. Assign everything and return
122123
if debug_info is None:
123-
value_solved = value_solved.at[idx_second_last, ...].set(
124-
out_dict_second_last["value"]
125-
)
126-
policy_solved = policy_solved.at[idx_second_last, ...].set(
127-
out_dict_second_last["policy"]
128-
)
129-
endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(
130-
out_dict_second_last["endog_grid"]
131-
)
132124
return (
133125
value_solved,
134126
policy_solved,
135127
endog_grid_solved,
136128
)
137129

138130
else:
139-
if "rescale_idx" in debug_info.keys():
140-
# If we want to rescale the idx, because we only solve part of the model, then to this first.
141-
idx_rescaled_second_last = idx_second_last - debug_info["rescale_idx"]
142-
# And then assign to the solution containers.
143-
value_solved = value_solved.at[idx_rescaled_second_last, ...].set(
144-
out_dict_second_last["value"]
145-
)
146-
policy_solved = policy_solved.at[idx_rescaled_second_last, ...].set(
147-
out_dict_second_last["policy"]
148-
)
149-
endog_grid_solved = endog_grid_solved.at[idx_rescaled_second_last, ...].set(
150-
out_dict_second_last["endog_grid"]
151-
)
152-
153131
# If candidates are also needed to returned we return them additionally to the solution containers.
154132
if debug_info["return_candidates"]:
155133
return (

src/dcegm/interfaces/inspect_solution.py

Lines changed: 137 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1+
import copy
2+
13
import jax.lax
24
import jax.numpy as jnp
35
import numpy as np
46

57
from dcegm.final_periods import solve_last_two_periods
68
from dcegm.law_of_motion import calc_cont_grids_next_period
79
from dcegm.pre_processing.sol_container import create_solution_container
10+
from dcegm.solve_single_period import solve_single_period
811

912

1013
def partially_solve(
@@ -28,6 +31,10 @@ def partially_solve(
2831
return_candidates: If True, additionally return candidate solutions before applying the upper envelope.
2932
3033
"""
34+
batch_info_internal = copy.deepcopy(batch_info)
35+
36+
if n_periods < 2:
37+
raise ValueError("You must at least solve for two periods.")
3138

3239
continuous_states_info = model_config["continuous_states_info"]
3340

@@ -38,7 +45,7 @@ def partially_solve(
3845
params=params,
3946
model_funcs=model_funcs,
4047
)
41-
48+
# Determine the last period we need to solve for.
4249
last_relevant_period = model_config["n_periods"] - n_periods
4350

4451
relevant_state_choices_mask = (
@@ -71,11 +78,21 @@ def partially_solve(
7178
)
7279
)
7380

81+
# Determine rescale idx for reduced solution
82+
rescale_idx = np.where(relevant_state_choices_mask)[0].min()
83+
7484
# Create debug information
7585
debug_info = {
7686
"return_candidates": return_candidates,
77-
"rescale_idx": np.where(relevant_state_choices_mask)[0].min(),
7887
}
88+
last_two_period_batch_info = batch_info_internal["last_two_period_info"]
89+
# Rescale the indexes to save of the last two periods:
90+
last_two_period_batch_info["idx_state_choices_final_period"] = (
91+
last_two_period_batch_info["idx_state_choices_final_period"] - rescale_idx
92+
)
93+
last_two_period_batch_info["idx_state_choices_second_last_period"] = (
94+
last_two_period_batch_info["idx_state_choices_second_last_period"] - rescale_idx
95+
)
7996
(
8097
value_solved,
8198
policy_solved,
@@ -89,25 +106,134 @@ def partially_solve(
89106
cont_grids_next_period=cont_grids_next_period,
90107
income_shock_weights=income_shock_weights,
91108
model_funcs=model_funcs,
92-
last_two_period_batch_info=batch_info["last_two_period_info"],
109+
last_two_period_batch_info=last_two_period_batch_info,
93110
value_solved=value_solved,
94111
policy_solved=policy_solved,
95112
endog_grid_solved=endog_grid_solved,
96113
debug_info=debug_info,
97114
)
98115
if return_candidates:
99-
idx_second_last = batch_info["last_two_period_info"][
116+
idx_second_last = batch_info_internal["last_two_period_info"][
100117
"idx_state_choices_second_last_period"
101118
]
102-
idx_second_last_rescaled = idx_second_last - debug_info["rescale_idx"]
103-
value_candidates = value_candidates.at[idx_second_last_rescaled, ...].set(
119+
value_candidates = value_candidates.at[idx_second_last, ...].set(
104120
value_candidates_second_last
105121
)
106-
policy_candidates = policy_candidates.at[idx_second_last_rescaled, ...].set(
122+
policy_candidates = policy_candidates.at[idx_second_last, ...].set(
107123
policy_candidates_second_last,
108124
)
109-
endog_grid_candidates = endog_grid_candidates.at[
110-
idx_second_last_rescaled, ...
111-
].set(endog_grid_candidates_second_last)
125+
endog_grid_candidates = endog_grid_candidates.at[idx_second_last, ...].set(
126+
endog_grid_candidates_second_last
127+
)
128+
129+
if n_periods <= 2:
130+
out_dict = {
131+
"value": value_solved,
132+
"policy": policy_solved,
133+
"endog_grid": endog_grid_solved,
134+
}
135+
if return_candidates:
136+
out_dict["value_candidates"] = value_candidates
137+
out_dict["policy_candidates"] = policy_candidates
138+
out_dict["endog_grid_candidates"] = endog_grid_candidates
139+
140+
return out_dict
141+
142+
stop_segment_loop = False
143+
for id_segment in range(batch_info_internal["n_segments"]):
144+
segment_info = batch_info_internal[f"batches_info_segment_{id_segment}"]
145+
146+
n_batches_in_segment = segment_info["batches_state_choice_idx"].shape[0]
112147

113-
return value_solved, policy_solved, endog_grid_solved
148+
for id_batch in range(n_batches_in_segment):
149+
periods_batch = segment_info["state_choices"]["period"][id_batch, :]
150+
151+
# Now there can be three cases:
152+
# 1) All periods are smaller than the last relevant period. Then we stop the loop
153+
# 2) Part of the periods are smaller than the last relevant period. Then we only solve for the partial state choices.
154+
# 3) All periods are larger than the last relevant period. Then we solve for state choices.
155+
if (periods_batch < last_relevant_period).all():
156+
stop_segment_loop = True
157+
break
158+
elif (periods_batch < last_relevant_period).any():
159+
solve_mask = periods_batch >= last_relevant_period
160+
state_choices_batch = {
161+
key: segment_info["state_choices"][key][id_batch, solve_mask]
162+
for key in segment_info["state_choices"].keys()
163+
}
164+
# We need to rescale the idx, because of saving
165+
idx_to_solve = (
166+
segment_info["batches_state_choice_idx"][id_batch, solve_mask]
167+
- rescale_idx
168+
)
169+
child_states_to_integrate_stochastic = segment_info[
170+
"child_states_to_integrate_stochastic"
171+
][id_batch, solve_mask, :]
172+
173+
else:
174+
state_choices_batch = {
175+
key: segment_info["state_choices"][key][id_batch, :]
176+
for key in segment_info["state_choices"].keys()
177+
}
178+
# We need to rescale the idx, because of saving
179+
idx_to_solve = (
180+
segment_info["batches_state_choice_idx"][id_batch, :] - rescale_idx
181+
)
182+
child_states_to_integrate_stochastic = segment_info[
183+
"child_states_to_integrate_stochastic"
184+
][id_batch, :, :]
185+
186+
state_choices_childs_batch = {
187+
key: segment_info["state_choices_childs"][key][id_batch, :]
188+
for key in segment_info["state_choices_childs"].keys()
189+
}
190+
xs = (
191+
idx_to_solve,
192+
segment_info["child_state_choices_to_aggr_choice"][id_batch, :, :],
193+
child_states_to_integrate_stochastic,
194+
segment_info["child_state_choice_idxs_to_interp"][id_batch, :],
195+
segment_info["child_states_idxs"][id_batch, :],
196+
state_choices_batch,
197+
state_choices_childs_batch,
198+
)
199+
carry = (value_solved, policy_solved, endog_grid_solved)
200+
single_period_out_dict = solve_single_period(
201+
carry=carry,
202+
xs=xs,
203+
params=params,
204+
continuous_grids_info=continuous_states_info,
205+
cont_grids_next_period=cont_grids_next_period,
206+
model_funcs=model_funcs,
207+
income_shock_weights=income_shock_weights,
208+
debug_info=debug_info,
209+
)
210+
211+
value_solved = single_period_out_dict["value"]
212+
policy_solved = single_period_out_dict["policy"]
213+
endog_grid_solved = single_period_out_dict["endog_grid"]
214+
215+
# If candidates are requested, we assign them to the solution container
216+
if return_candidates:
217+
value_candidates = value_candidates.at[idx_to_solve, ...].set(
218+
single_period_out_dict["value_candidates"]
219+
)
220+
policy_candidates = policy_candidates.at[idx_to_solve, ...].set(
221+
single_period_out_dict["policy_candidates"]
222+
)
223+
endog_grid_candidates = endog_grid_candidates.at[idx_to_solve, ...].set(
224+
single_period_out_dict["endog_grid_candidates"]
225+
)
226+
227+
if stop_segment_loop:
228+
break
229+
230+
out_dict = {
231+
"value": value_solved,
232+
"policy": policy_solved,
233+
"endog_grid": endog_grid_solved,
234+
}
235+
if return_candidates:
236+
out_dict["value_candidates"] = value_candidates
237+
out_dict["policy_candidates"] = policy_candidates
238+
out_dict["endog_grid_candidates"] = endog_grid_candidates
239+
return out_dict

src/dcegm/solve_single_period.py

Lines changed: 24 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -75,46 +75,36 @@ def solve_single_period(
7575
model_funcs=model_funcs,
7676
debug_info=debug_info,
7777
)
78+
value_solved = value_solved.at[state_choices_idxs, :].set(out_dict_period["value"])
79+
policy_solved = policy_solved.at[state_choices_idxs, :].set(
80+
out_dict_period["policy"]
81+
)
82+
endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set(
83+
out_dict_period["endog_grid"]
84+
)
7885

79-
# If we are not in the debug mode, we only return the solved values.
86+
# If we are not in the debug mode, we only return the solution as a tuple and an empty tuple.
8087
if debug_info is None:
81-
82-
value_solved = value_solved.at[state_choices_idxs, :].set(
83-
out_dict_period["value"]
84-
)
85-
policy_solved = policy_solved.at[state_choices_idxs, :].set(
86-
out_dict_period["policy"]
87-
)
88-
endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set(
89-
out_dict_period["endog_grid"]
90-
)
9188
carry = (value_solved, policy_solved, endog_grid_solved)
89+
return carry, ()
9290

9391
else:
94-
if "rescale_idx" in debug_info.keys():
95-
state_choices_idxs = state_choices_idxs - debug_info["rescale_idx"]
96-
value_solved = value_solved.at[state_choices_idxs, :].set(
97-
out_dict_period["value"]
98-
)
99-
policy_solved = policy_solved.at[state_choices_idxs, :].set(
100-
out_dict_period["policy"]
101-
)
102-
endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set(
103-
out_dict_period["endog_grid"]
104-
)
105-
if debug_info["return_candidates"]:
106-
carry = (
107-
value_solved,
108-
policy_solved,
109-
endog_grid_solved,
110-
out_dict_period["value_candidates"],
111-
out_dict_period["policy_candidates"],
112-
out_dict_period["endog_grid_candidates"],
113-
)
114-
else:
115-
carry = (value_solved, policy_solved, endog_grid_solved)
92+
# In debug mode we return a dictionary.
93+
out_dict = {
94+
"value": value_solved,
95+
"policy": policy_solved,
96+
"endog_grid": endog_grid_solved,
97+
}
11698

117-
return carry, ()
99+
# If candidates are requested, we add them
100+
if debug_info["return_candidates"]:
101+
out_dict = {
102+
**out_dict,
103+
"value_candidates": out_dict_period["value_candidates"],
104+
"policy_candidates": out_dict_period["policy_candidates"],
105+
"endog_grid_candidates": out_dict_period["endog_grid_candidates"],
106+
}
107+
return out_dict
118108

119109

120110
def solve_for_interpolated_values(

0 commit comments

Comments
 (0)