@@ -333,6 +333,63 @@ def get_child_states_and_calc_trans_probs(self, state, choice, params):
333
333
child_states_df ["trans_probs" ] = trans_probs
334
334
return child_states_df
335
335
336
+ def get_full_child_states_by_asset_id_and_probs (
337
+ self , state , choice , params , asset_id , second_continuous_id = None
338
+ ):
339
+ """Get the child states for a given state and choice and calculate the
340
+ transition probabilities."""
341
+ if "map_state_choice_to_child_states" not in self .model_structure :
342
+ raise ValueError (
343
+ "For this function the model needs to be created with debug_info='all'"
344
+ )
345
+
346
+ child_idx = get_child_state_index_per_state_choice (
347
+ states = state , choice = choice , model_structure = self .model_structure
348
+ )
349
+ state_space_dict = self .model_structure ["state_space_dict" ]
350
+ discrete_states_names = self .model_structure ["discrete_states_names" ]
351
+ child_states = {
352
+ key : state_space_dict [key ][child_idx ] for key in discrete_states_names
353
+ }
354
+ child_states_df = pd .DataFrame (child_states )
355
+
356
+ child_continuous_states = self .compute_law_of_motions (params = params )
357
+
358
+ if "second_continuous" in child_continuous_states .keys ():
359
+ if second_continuous_id is None :
360
+ raise ValueError ("second_continuous_id must be provided." )
361
+ else :
362
+ quad_wealth = child_continuous_states ["assets_begin_of_period" ][
363
+ child_idx , second_continuous_id , asset_id , :
364
+ ]
365
+ next_period_second_continuous = child_continuous_states [
366
+ "second_continuous"
367
+ ][child_idx , second_continuous_id ]
368
+
369
+ second_continuous_name = self .model_config ["continuous_states_info" ][
370
+ "second_continuous_state_name"
371
+ ]
372
+ child_states_df [second_continuous_name ] = next_period_second_continuous
373
+
374
+ else :
375
+ if second_continuous_id is not None :
376
+ raise ValueError ("second_continuous_id must not be provided." )
377
+ else :
378
+ quad_wealth = child_continuous_states ["assets_begin_of_period" ][
379
+ child_idx , asset_id , :
380
+ ]
381
+
382
+ for id_quad in range (quad_wealth .shape [1 ]):
383
+ child_states_df [f"assets_begin_of_period_quad_point_{ id_quad } " ] = (
384
+ quad_wealth [:, id_quad ]
385
+ )
386
+
387
+ trans_probs = self .model_funcs ["compute_stochastic_transition_vec" ](
388
+ params = params , choice = choice , ** state
389
+ )
390
+ child_states_df ["trans_probs" ] = trans_probs
391
+ return child_states_df
392
+
336
393
def compute_law_of_motions (self , params ):
337
394
return calc_cont_grids_next_period (
338
395
params = params ,
0 commit comments