@@ -39,7 +39,7 @@ def partially_solve(
39
39
model_funcs = model_funcs ,
40
40
)
41
41
42
- last_relevant_period = model_config ["n_periods" ] - n_periods - 1
42
+ last_relevant_period = model_config ["n_periods" ] - n_periods
43
43
44
44
relevant_state_choices_mask = (
45
45
model_structure ["state_choice_space" ][:, 0 ] >= last_relevant_period
@@ -53,14 +53,61 @@ def partially_solve(
53
53
policy_solved ,
54
54
endog_grid_solved ,
55
55
) = create_solution_container (
56
- model_config = model_config ,
56
+ continuous_states_info = model_config ["continuous_states_info" ],
57
+ # Read out grid size
58
+ n_total_wealth_grid = model_config ["tuning_params" ]["n_total_wealth_grid" ],
57
59
n_state_choices = relevant_state_choice_space .shape [0 ],
58
60
)
59
61
60
62
if return_candidates :
63
+ n_assets_end_of_period = model_config ["continuous_states_info" ][
64
+ "assets_grid_end_of_period"
65
+ ].shape [0 ]
61
66
(value_candidates , policy_candidates , endog_grid_candidates ) = (
62
67
create_solution_container (
63
- model_config = model_config ,
68
+ continuous_states_info = model_config ["continuous_states_info" ],
69
+ n_total_wealth_grid = n_assets_end_of_period ,
64
70
n_state_choices = relevant_state_choice_space .shape [0 ],
65
71
)
66
72
)
73
+
74
+ # Create debug information
75
+ debug_info = {
76
+ "return_candidates" : return_candidates ,
77
+ "rescale_idx" : np .where (relevant_state_choices_mask )[0 ].min (),
78
+ }
79
+ (
80
+ value_solved ,
81
+ policy_solved ,
82
+ endog_grid_solved ,
83
+ value_candidates_second_last ,
84
+ policy_candidates_second_last ,
85
+ endog_grid_candidates_second_last ,
86
+ ) = solve_last_two_periods (
87
+ params = params ,
88
+ continuous_states_info = continuous_states_info ,
89
+ cont_grids_next_period = cont_grids_next_period ,
90
+ income_shock_weights = income_shock_weights ,
91
+ model_funcs = model_funcs ,
92
+ last_two_period_batch_info = batch_info ["last_two_period_info" ],
93
+ value_solved = value_solved ,
94
+ policy_solved = policy_solved ,
95
+ endog_grid_solved = endog_grid_solved ,
96
+ debug_info = debug_info ,
97
+ )
98
+ if return_candidates :
99
+ idx_second_last = batch_info ["last_two_period_info" ][
100
+ "idx_state_choices_second_last_period"
101
+ ]
102
+ idx_second_last_rescaled = idx_second_last - debug_info ["rescale_idx" ]
103
+ value_candidates = value_candidates .at [idx_second_last_rescaled , ...].set (
104
+ value_candidates_second_last
105
+ )
106
+ policy_candidates = policy_candidates .at [idx_second_last_rescaled , ...].set (
107
+ policy_candidates_second_last ,
108
+ )
109
+ endog_grid_candidates = endog_grid_candidates .at [
110
+ idx_second_last_rescaled , ...
111
+ ].set (endog_grid_candidates_second_last )
112
+
113
+ return value_solved , policy_solved , endog_grid_solved
0 commit comments