Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions docs/source/background/two_period_model_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -749,15 +749,10 @@
]
},
{
"metadata": {},
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2025-06-30T09:22:14.977528Z",
"start_time": "2025-06-30T09:22:14.323938Z"
}
},
"outputs": [],
"execution_count": null,
"source": [
"state_dict = {\n",
" \"ltc\": initial_condition[\"health\"],\n",
Expand All @@ -767,9 +762,9 @@
"}\n",
"\n",
"\n",
"cons_calc, value = solved_model.value_and_policy_for_state_and_choice(\n",
" state=state_dict,\n",
" choice=choice_in_period_0,\n",
"cons_calc, value = solved_model.policy_and_value_for_states_and_choices(\n",
" states=state_dict,\n",
" choices=choice_in_period_0,\n",
")"
]
},
Expand Down
6 changes: 3 additions & 3 deletions src/dcegm/asset_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)


def adjust_observed_assets(observed_states_dict, params, model_class):
def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=False):
"""Correct observed beginning of period assets data for likelihood estimation.

Assets in empirical survey data is observed without the income of last period's
Expand Down Expand Up @@ -47,7 +47,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class):
jnp.array(0.0, dtype=jnp.float64),
params,
model_funcs["compute_assets_begin_of_period"],
False,
aux_outs,
)

else:
Expand All @@ -60,7 +60,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class):
jnp.array(0.0, dtype=jnp.float64),
params,
model_funcs["compute_assets_begin_of_period"],
False,
aux_outs,
)

return adjusted_assets
189 changes: 56 additions & 133 deletions src/dcegm/backward_induction.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Interface for the DC-EGM algorithm."""

from functools import partial
from typing import Any, Callable, Dict, Tuple

import jax
import jax.lax
import jax.numpy as jnp
import numpy as np

from dcegm.final_periods import solve_last_two_periods
from dcegm.law_of_motion import calc_cont_grids_next_period
from dcegm.pre_processing.sol_container import create_solution_container
from dcegm.solve_single_period import solve_single_period


Expand All @@ -25,117 +25,90 @@ def backward_induction(

Args:
params (dict): Dictionary containing the model parameters.
options (dict): Dictionary containing the model options.
period_specific_state_objects (np.ndarray): Dictionary containing
period-specific state and state-choice objects, with the following keys:
- "state_choice_mat" (jnp.ndarray)
- "idx_state_of_state_choice" (jnp.ndarray)
- "reshape_state_choice_vec_to_mat" (callable)
- "transform_between_state_and_state_choice_vec" (callable)
exog_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,)
containing the exogenous savings grid.
has_second_continuous_state (bool): Boolean indicating whether the model
features a second continuous state variable. If False, the only
continuous state variable is consumption/savings.
state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1)
which serves as a collection of all possible states. By convention,
the first column must contain the period and the last column the
exogenous processes. Any other state variables are in between.
E.g. if the two state variables are period and lagged choice and all choices
are admissible in each period, the shape of the state space array is
(n_periods * n_choices, 3).
state_choice_space (np.ndarray): 2d array of shape
(n_feasible_states, n_state_and_exog_variables + 1) containing all
feasible state-choice combinations. By convention, the second to last
column contains the exogenous process. The last column always contains the
choice to be made (which is not a state variable).
income_shock_draws_unscaled (np.ndarray): 1d array of shape (n_quad_points,)
containing the Hermite quadrature points unscaled.
income_shock_weights (np.ndarrray): 1d array of shape
(n_stochastic_quad_points) with weights for each stoachstic shock draw.
n_periods (int): Number of periods.
model_funcs (dict): Dictionary containing following model functions:
- compute_marginal_utility (callable): User-defined function to compute the
agent's marginal utility. The input ```params``` is already partialled
in.
- compute_inverse_marginal_utility (Callable): Function for calculating the
inverse marginal utiFality, which takes the marginal utility as only
input.
- compute_next_period_wealth (callable): User-defined function to compute
the agent's wealth of the next period (t + 1). The inputs
```saving```, ```shock```, ```params``` and ```options```
are already partialled in.
- transition_vector_by_state (Callable): Partialled transition function
return transition vector for each state.
- final_period_partial (Callable): Partialled function for calculating the
consumption as well as value function and marginal utility in the final
period.
compute_upper_envelope (Callable): Function for calculating the upper
envelope of the policy and value function. If the number of discrete
choices is 1, this function is a dummy function that returns the policy
and value function as is, without performing a fast upper envelope
scan.
model_config (dict): Dictionary containing the model configuration.
model_funcs (dict): Dictionary containing model functions.
model_structure (dict): Dictionary containing model structure.
batch_info (dict): Dictionary containing batch information.

Returns:
dict: Dictionary containing the period-specific endog_grid, policy, and value
Tuple: Tuple containing the period-specific endog_grid, policy, and value
from the backward induction.

"""
continuous_states_info = model_config["continuous_states_info"]

cont_grids_next_period = calc_cont_grids_next_period(
model_structure=model_structure,
model_config=model_config,
income_shock_draws_unscaled=income_shock_draws_unscaled,
params=params,
model_funcs=model_funcs,
#
calc_grids_jit = jax.jit(
lambda income_shock_draws, params_inner: calc_cont_grids_next_period(
model_structure=model_structure,
model_config=model_config,
income_shock_draws_unscaled=income_shock_draws,
params=params_inner,
model_funcs=model_funcs,
)
)

# Create solution containers. The 20 percent extra in wealth grid needs to go
# into tuning parameters
n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"]
cont_grids_next_period = calc_grids_jit(income_shock_draws_unscaled, params)

(
value_solved,
policy_solved,
endog_grid_solved,
) = create_solution_container(
model_config=model_config,
model_structure=model_structure,
continuous_states_info=model_config["continuous_states_info"],
# Read out grid size
n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"],
n_state_choices=model_structure["state_choice_space"].shape[0],
)

# Solve the last two periods using lambda to capture static arguments
solve_last_two_period_jit = jax.jit(
lambda params_inner, cont_grids, weights, val_solved, pol_solved, endog_solved: solve_last_two_periods(
params=params_inner,
continuous_states_info=continuous_states_info,
cont_grids_next_period=cont_grids,
income_shock_weights=weights,
model_funcs=model_funcs,
last_two_period_batch_info=batch_info["last_two_period_info"],
value_solved=val_solved,
policy_solved=pol_solved,
endog_grid_solved=endog_solved,
debug_info=None,
)
)

# Solve the last two periods. We do this separately as the marginal utility of
# the child states in the last period is calculated from the marginal utility
# function of the bequest function, which might differ.
(
value_solved,
policy_solved,
endog_grid_solved,
) = solve_last_two_periods(
params=params,
continuous_states_info=continuous_states_info,
cont_grids_next_period=cont_grids_next_period,
income_shock_weights=income_shock_weights,
model_funcs=model_funcs,
last_two_period_batch_info=batch_info["last_two_period_info"],
value_solved=value_solved,
policy_solved=policy_solved,
endog_grid_solved=endog_grid_solved,
) = solve_last_two_period_jit(
params,
cont_grids_next_period,
income_shock_weights,
value_solved,
policy_solved,
endog_grid_solved,
)

# If it is a two period model we are done.
if batch_info["two_period_model"]:
return value_solved, policy_solved, endog_grid_solved

def partial_single_period(carry, xs):
return solve_single_period(
carry=carry,
xs=xs,
params=params,
continuous_grids_info=continuous_states_info,
cont_grids_next_period=cont_grids_next_period,
model_funcs=model_funcs,
income_shock_weights=income_shock_weights,
)
# Create JIT-compiled single period solver using lambda
partial_single_period = lambda carry, xs: solve_single_period(
carry=carry,
xs=xs,
params=params,
continuous_grids_info=continuous_states_info,
cont_grids_next_period=cont_grids_next_period,
model_funcs=model_funcs,
income_shock_weights=income_shock_weights,
debug_info=None,
)

for id_segment in range(batch_info["n_segments"]):
segment_info = batch_info[f"batches_info_segment_{id_segment}"]
Expand Down Expand Up @@ -192,53 +165,3 @@ def partial_single_period(carry, xs):
policy_solved,
endog_grid_solved,
)


def create_solution_container(
model_config: Dict[str, Any],
model_structure: Dict[str, Any],
):
"""Create solution containers for value, policy, and endog_grid."""

# Read out grid size
n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"]
n_state_choices = model_structure["state_choice_space"].shape[0]

# Check if second continuous state exists and read out array size
continuous_states_info = model_config["continuous_states_info"]
if continuous_states_info["second_continuous_exists"]:
n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"]

value_solved = jnp.full(
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)
policy_solved = jnp.full(
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)
endog_grid_solved = jnp.full(
(n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)
else:
value_solved = jnp.full(
(n_state_choices, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)
policy_solved = jnp.full(
(n_state_choices, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)
endog_grid_solved = jnp.full(
(n_state_choices, n_total_wealth_grid),
dtype=jnp.float64,
fill_value=jnp.nan,
)

return value_solved, policy_solved, endog_grid_solved
6 changes: 3 additions & 3 deletions src/dcegm/egm/aggregate_marginal_utility.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Tuple

import jax.numpy as jnp
import numpy as np


def aggregate_marg_utils_and_exp_values(
value_state_choice_specific: jnp.ndarray,
marg_util_state_choice_specific: jnp.ndarray,
reshape_state_choice_vec_to_mat: np.ndarray,
reshape_state_choice_vec_to_mat: jnp.ndarray,
taste_shock_scale,
taste_shock_scale_is_scalar,
income_shock_weights: jnp.ndarray,
Expand Down Expand Up @@ -47,11 +46,12 @@ def aggregate_marg_utils_and_exp_values(
mode="fill",
fill_value=jnp.nan,
)

# If taste shock is not scalar, we select from the array,
# where we have for each choice a taste shock scale one. They are by construction
# the same for all choices in a state
if not taste_shock_scale_is_scalar:
one_choice_per_state = np.min(reshape_state_choice_vec_to_mat, axis=1)
one_choice_per_state = jnp.min(reshape_state_choice_vec_to_mat, axis=1)
taste_shock_scale = jnp.take(
taste_shock_scale,
one_choice_per_state,
Expand Down
6 changes: 3 additions & 3 deletions src/dcegm/egm/interpolate_marginal_utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,9 @@ def interp1d_value_and_marg_util_for_state_choice(
def interp_on_single_wealth_point(wealth_point):
policy_interp, value_interp = interp1d_policy_and_value_on_wealth(
wealth=wealth_point,
endog_grid=endog_grid_child_state_choice,
policy=policy_child_state_choice,
value=value_child_state_choice,
wealth_grid=endog_grid_child_state_choice,
policy_grid=policy_child_state_choice,
value_grid=value_child_state_choice,
compute_utility=compute_utility,
state_choice_vec=state_choice_vec,
params=params,
Expand Down
Loading
Loading