Skip to content

Commit 0a574ca

Browse files
committed
Fix options.
1 parent cc734f5 commit 0a574ca

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

src/dcegm/interfaces/inspect_solution.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,17 @@ def partially_solve(
187187
key: segment_info["state_choices_childs"][key][id_batch, :]
188188
for key in segment_info["state_choices_childs"].keys()
189189
}
190+
191+
child_state_choice_idxs_to_interp = (
192+
segment_info["child_state_choice_idxs_to_interp"][id_batch, :]
193+
- rescale_idx
194+
)
195+
190196
xs = (
191197
idx_to_solve,
192198
segment_info["child_state_choices_to_aggr_choice"][id_batch, :, :],
193199
child_states_to_integrate_stochastic,
194-
segment_info["child_state_choice_idxs_to_interp"][id_batch, :],
200+
child_state_choice_idxs_to_interp,
195201
segment_info["child_states_idxs"][id_batch, :],
196202
state_choices_batch,
197203
state_choices_childs_batch,

src/dcegm/interfaces/model_class.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -333,6 +333,63 @@ def get_child_states_and_calc_trans_probs(self, state, choice, params):
333333
child_states_df["trans_probs"] = trans_probs
334334
return child_states_df
335335

336+
def get_full_child_states_by_asset_id_and_probs(
337+
self, state, choice, params, asset_id, second_continuous_id=None
338+
):
339+
"""Get the child states for a given state and choice and calculate the
340+
transition probabilities."""
341+
if "map_state_choice_to_child_states" not in self.model_structure:
342+
raise ValueError(
343+
"For this function the model needs to be created with debug_info='all'"
344+
)
345+
346+
child_idx = get_child_state_index_per_state_choice(
347+
states=state, choice=choice, model_structure=self.model_structure
348+
)
349+
state_space_dict = self.model_structure["state_space_dict"]
350+
discrete_states_names = self.model_structure["discrete_states_names"]
351+
child_states = {
352+
key: state_space_dict[key][child_idx] for key in discrete_states_names
353+
}
354+
child_states_df = pd.DataFrame(child_states)
355+
356+
child_continuous_states = self.compute_law_of_motions(params=params)
357+
358+
if "second_continuous" in child_continuous_states.keys():
359+
if second_continuous_id is None:
360+
raise ValueError("second_continuous_id must be provided.")
361+
else:
362+
quad_wealth = child_continuous_states["assets_begin_of_period"][
363+
child_idx, second_continuous_id, asset_id, :
364+
]
365+
next_period_second_continuous = child_continuous_states[
366+
"second_continuous"
367+
][child_idx, second_continuous_id]
368+
369+
second_continuous_name = self.model_config["continuous_states_info"][
370+
"second_continuous_state_name"
371+
]
372+
child_states_df[second_continuous_name] = next_period_second_continuous
373+
374+
else:
375+
if second_continuous_id is not None:
376+
raise ValueError("second_continuous_id must not be provided.")
377+
else:
378+
quad_wealth = child_continuous_states["assets_begin_of_period"][
379+
child_idx, asset_id, :
380+
]
381+
382+
for id_quad in range(quad_wealth.shape[1]):
383+
child_states_df[f"assets_begin_of_period_quad_point_{id_quad}"] = (
384+
quad_wealth[:, id_quad]
385+
)
386+
387+
trans_probs = self.model_funcs["compute_stochastic_transition_vec"](
388+
params=params, choice=choice, **state
389+
)
390+
child_states_df["trans_probs"] = trans_probs
391+
return child_states_df
392+
336393
def compute_law_of_motions(self, params):
337394
return calc_cont_grids_next_period(
338395
params=params,

0 commit comments

Comments
 (0)