1
+ import copy
2
+
1
3
import jax .lax
2
4
import jax .numpy as jnp
3
5
import numpy as np
4
6
5
7
from dcegm .final_periods import solve_last_two_periods
6
8
from dcegm .law_of_motion import calc_cont_grids_next_period
7
9
from dcegm .pre_processing .sol_container import create_solution_container
10
+ from dcegm .solve_single_period import solve_single_period
8
11
9
12
10
13
def partially_solve (
@@ -28,6 +31,10 @@ def partially_solve(
28
31
return_candidates: If True, additionally return candidate solutions before applying the upper envelope.
29
32
30
33
"""
34
+ batch_info_internal = copy .deepcopy (batch_info )
35
+
36
+ if n_periods < 2 :
37
+ raise ValueError ("You must at least solve for two periods." )
31
38
32
39
continuous_states_info = model_config ["continuous_states_info" ]
33
40
@@ -38,7 +45,7 @@ def partially_solve(
38
45
params = params ,
39
46
model_funcs = model_funcs ,
40
47
)
41
-
48
+ # Determine the last period we need to solve for.
42
49
last_relevant_period = model_config ["n_periods" ] - n_periods
43
50
44
51
relevant_state_choices_mask = (
@@ -71,11 +78,21 @@ def partially_solve(
71
78
)
72
79
)
73
80
81
+ # Determine rescale idx for reduced solution
82
+ rescale_idx = np .where (relevant_state_choices_mask )[0 ].min ()
83
+
74
84
# Create debug information
75
85
debug_info = {
76
86
"return_candidates" : return_candidates ,
77
- "rescale_idx" : np .where (relevant_state_choices_mask )[0 ].min (),
78
87
}
88
+ last_two_period_batch_info = batch_info_internal ["last_two_period_info" ]
89
+ # Rescale the indexes to save of the last two periods:
90
+ last_two_period_batch_info ["idx_state_choices_final_period" ] = (
91
+ last_two_period_batch_info ["idx_state_choices_final_period" ] - rescale_idx
92
+ )
93
+ last_two_period_batch_info ["idx_state_choices_second_last_period" ] = (
94
+ last_two_period_batch_info ["idx_state_choices_second_last_period" ] - rescale_idx
95
+ )
79
96
(
80
97
value_solved ,
81
98
policy_solved ,
@@ -89,25 +106,134 @@ def partially_solve(
89
106
cont_grids_next_period = cont_grids_next_period ,
90
107
income_shock_weights = income_shock_weights ,
91
108
model_funcs = model_funcs ,
92
- last_two_period_batch_info = batch_info [ "last_two_period_info" ] ,
109
+ last_two_period_batch_info = last_two_period_batch_info ,
93
110
value_solved = value_solved ,
94
111
policy_solved = policy_solved ,
95
112
endog_grid_solved = endog_grid_solved ,
96
113
debug_info = debug_info ,
97
114
)
98
115
if return_candidates :
99
- idx_second_last = batch_info ["last_two_period_info" ][
116
+ idx_second_last = batch_info_internal ["last_two_period_info" ][
100
117
"idx_state_choices_second_last_period"
101
118
]
102
- idx_second_last_rescaled = idx_second_last - debug_info ["rescale_idx" ]
103
- value_candidates = value_candidates .at [idx_second_last_rescaled , ...].set (
119
+ value_candidates = value_candidates .at [idx_second_last , ...].set (
104
120
value_candidates_second_last
105
121
)
106
- policy_candidates = policy_candidates .at [idx_second_last_rescaled , ...].set (
122
+ policy_candidates = policy_candidates .at [idx_second_last , ...].set (
107
123
policy_candidates_second_last ,
108
124
)
109
- endog_grid_candidates = endog_grid_candidates .at [
110
- idx_second_last_rescaled , ...
111
- ].set (endog_grid_candidates_second_last )
125
+ endog_grid_candidates = endog_grid_candidates .at [idx_second_last , ...].set (
126
+ endog_grid_candidates_second_last
127
+ )
128
+
129
+ if n_periods <= 2 :
130
+ out_dict = {
131
+ "value" : value_solved ,
132
+ "policy" : policy_solved ,
133
+ "endog_grid" : endog_grid_solved ,
134
+ }
135
+ if return_candidates :
136
+ out_dict ["value_candidates" ] = value_candidates
137
+ out_dict ["policy_candidates" ] = policy_candidates
138
+ out_dict ["endog_grid_candidates" ] = endog_grid_candidates
139
+
140
+ return out_dict
141
+
142
+ stop_segment_loop = False
143
+ for id_segment in range (batch_info_internal ["n_segments" ]):
144
+ segment_info = batch_info_internal [f"batches_info_segment_{ id_segment } " ]
145
+
146
+ n_batches_in_segment = segment_info ["batches_state_choice_idx" ].shape [0 ]
112
147
113
- return value_solved , policy_solved , endog_grid_solved
148
+ for id_batch in range (n_batches_in_segment ):
149
+ periods_batch = segment_info ["state_choices" ]["period" ][id_batch , :]
150
+
151
+ # Now there can be three cases:
152
+ # 1) All periods are smaller than the last relevant period. Then we stop the loop
153
+ # 2) Part of the periods are smaller than the last relevant period. Then we only solve for the partial state choices.
154
+ # 3) All periods are larger than the last relevant period. Then we solve for state choices.
155
+ if (periods_batch < last_relevant_period ).all ():
156
+ stop_segment_loop = True
157
+ break
158
+ elif (periods_batch < last_relevant_period ).any ():
159
+ solve_mask = periods_batch >= last_relevant_period
160
+ state_choices_batch = {
161
+ key : segment_info ["state_choices" ][key ][id_batch , solve_mask ]
162
+ for key in segment_info ["state_choices" ].keys ()
163
+ }
164
+ # We need to rescale the idx, because of saving
165
+ idx_to_solve = (
166
+ segment_info ["batches_state_choice_idx" ][id_batch , solve_mask ]
167
+ - rescale_idx
168
+ )
169
+ child_states_to_integrate_stochastic = segment_info [
170
+ "child_states_to_integrate_stochastic"
171
+ ][id_batch , solve_mask , :]
172
+
173
+ else :
174
+ state_choices_batch = {
175
+ key : segment_info ["state_choices" ][key ][id_batch , :]
176
+ for key in segment_info ["state_choices" ].keys ()
177
+ }
178
+ # We need to rescale the idx, because of saving
179
+ idx_to_solve = (
180
+ segment_info ["batches_state_choice_idx" ][id_batch , :] - rescale_idx
181
+ )
182
+ child_states_to_integrate_stochastic = segment_info [
183
+ "child_states_to_integrate_stochastic"
184
+ ][id_batch , :, :]
185
+
186
+ state_choices_childs_batch = {
187
+ key : segment_info ["state_choices_childs" ][key ][id_batch , :]
188
+ for key in segment_info ["state_choices_childs" ].keys ()
189
+ }
190
+ xs = (
191
+ idx_to_solve ,
192
+ segment_info ["child_state_choices_to_aggr_choice" ][id_batch , :, :],
193
+ child_states_to_integrate_stochastic ,
194
+ segment_info ["child_state_choice_idxs_to_interp" ][id_batch , :],
195
+ segment_info ["child_states_idxs" ][id_batch , :],
196
+ state_choices_batch ,
197
+ state_choices_childs_batch ,
198
+ )
199
+ carry = (value_solved , policy_solved , endog_grid_solved )
200
+ single_period_out_dict = solve_single_period (
201
+ carry = carry ,
202
+ xs = xs ,
203
+ params = params ,
204
+ continuous_grids_info = continuous_states_info ,
205
+ cont_grids_next_period = cont_grids_next_period ,
206
+ model_funcs = model_funcs ,
207
+ income_shock_weights = income_shock_weights ,
208
+ debug_info = debug_info ,
209
+ )
210
+
211
+ value_solved = single_period_out_dict ["value" ]
212
+ policy_solved = single_period_out_dict ["policy" ]
213
+ endog_grid_solved = single_period_out_dict ["endog_grid" ]
214
+
215
+ # If candidates are requested, we assign them to the solution container
216
+ if return_candidates :
217
+ value_candidates = value_candidates .at [idx_to_solve , ...].set (
218
+ single_period_out_dict ["value_candidates" ]
219
+ )
220
+ policy_candidates = policy_candidates .at [idx_to_solve , ...].set (
221
+ single_period_out_dict ["policy_candidates" ]
222
+ )
223
+ endog_grid_candidates = endog_grid_candidates .at [idx_to_solve , ...].set (
224
+ single_period_out_dict ["endog_grid_candidates" ]
225
+ )
226
+
227
+ if stop_segment_loop :
228
+ break
229
+
230
+ out_dict = {
231
+ "value" : value_solved ,
232
+ "policy" : policy_solved ,
233
+ "endog_grid" : endog_grid_solved ,
234
+ }
235
+ if return_candidates :
236
+ out_dict ["value_candidates" ] = value_candidates
237
+ out_dict ["policy_candidates" ] = policy_candidates
238
+ out_dict ["endog_grid_candidates" ] = endog_grid_candidates
239
+ return out_dict
0 commit comments