diff --git a/docs/source/background/two_period_model_tutorial.ipynb b/docs/source/background/two_period_model_tutorial.ipynb index 3d0fb177..572354ab 100644 --- a/docs/source/background/two_period_model_tutorial.ipynb +++ b/docs/source/background/two_period_model_tutorial.ipynb @@ -749,15 +749,10 @@ ] }, { + "metadata": {}, "cell_type": "code", - "execution_count": 17, - "metadata": { - "ExecuteTime": { - "end_time": "2025-06-30T09:22:14.977528Z", - "start_time": "2025-06-30T09:22:14.323938Z" - } - }, "outputs": [], + "execution_count": null, "source": [ "state_dict = {\n", " \"ltc\": initial_condition[\"health\"],\n", @@ -767,9 +762,9 @@ "}\n", "\n", "\n", - "cons_calc, value = solved_model.value_and_policy_for_state_and_choice(\n", - " state=state_dict,\n", - " choice=choice_in_period_0,\n", + "cons_calc, value = solved_model.policy_and_value_for_states_and_choices(\n", + " states=state_dict,\n", + " choices=choice_in_period_0,\n", ")" ] }, diff --git a/src/dcegm/asset_correction.py b/src/dcegm/asset_correction.py index f9bc2ddc..17a6e2b4 100644 --- a/src/dcegm/asset_correction.py +++ b/src/dcegm/asset_correction.py @@ -7,7 +7,7 @@ ) -def adjust_observed_assets(observed_states_dict, params, model_class): +def adjust_observed_assets(observed_states_dict, params, model_class, aux_outs=False): """Correct observed beginning of period assets data for likelihood estimation. Assets in empirical survey data is observed without the income of last period's @@ -47,7 +47,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class): jnp.array(0.0, dtype=jnp.float64), params, model_funcs["compute_assets_begin_of_period"], - False, + aux_outs, ) else: @@ -60,7 +60,7 @@ def adjust_observed_assets(observed_states_dict, params, model_class): jnp.array(0.0, dtype=jnp.float64), params, model_funcs["compute_assets_begin_of_period"], - False, + aux_outs, ) return adjusted_assets diff --git a/src/dcegm/backward_induction.py b/src/dcegm/backward_induction.py index 470b46f7..a6bb1b87 100644 --- a/src/dcegm/backward_induction.py +++ b/src/dcegm/backward_induction.py @@ -1,14 +1,14 @@ """Interface for the DC-EGM algorithm.""" -from functools import partial from typing import Any, Callable, Dict, Tuple +import jax import jax.lax import jax.numpy as jnp -import numpy as np from dcegm.final_periods import solve_last_two_periods from dcegm.law_of_motion import calc_cont_grids_next_period +from dcegm.pre_processing.sol_container import create_solution_container from dcegm.solve_single_period import solve_single_period @@ -25,117 +25,90 @@ def backward_induction( Args: params (dict): Dictionary containing the model parameters. - options (dict): Dictionary containing the model options. - period_specific_state_objects (np.ndarray): Dictionary containing - period-specific state and state-choice objects, with the following keys: - - "state_choice_mat" (jnp.ndarray) - - "idx_state_of_state_choice" (jnp.ndarray) - - "reshape_state_choice_vec_to_mat" (callable) - - "transform_between_state_and_state_choice_vec" (callable) - exog_savings_grid (np.ndarray): 1d array of shape (n_grid_wealth,) - containing the exogenous savings grid. - has_second_continuous_state (bool): Boolean indicating whether the model - features a second continuous state variable. If False, the only - continuous state variable is consumption/savings. - state_space (np.ndarray): 2d array of shape (n_states, n_state_variables + 1) - which serves as a collection of all possible states. 
By convention,
-            the first column must contain the period and the last column the
-            exogenous processes. Any other state variables are in between.
-            E.g. if the two state variables are period and lagged choice and all choices
-            are admissible in each period, the shape of the state space array is
-            (n_periods * n_choices, 3).
-        state_choice_space (np.ndarray): 2d array of shape
-            (n_feasible_states, n_state_and_exog_variables + 1) containing all
-            feasible state-choice combinations. By convention, the second to last
-            column contains the exogenous process. The last column always contains the
-            choice to be made (which is not a state variable).
         income_shock_draws_unscaled (np.ndarray): 1d array of shape (n_quad_points,)
             containing the Hermite quadrature points unscaled.
         income_shock_weights (np.ndarray): 1d array of shape
             (n_stochastic_quad_points,) with weights for each stochastic shock draw.
-        n_periods (int): Number of periods.
-        model_funcs (dict): Dictionary containing following model functions:
-            - compute_marginal_utility (callable): User-defined function to compute the
-                agent's marginal utility. The input ```params``` is already partialled
-                in.
-            - compute_inverse_marginal_utility (Callable): Function for calculating the
-                inverse marginal utiFality, which takes the marginal utility as only
-                input.
-            - compute_next_period_wealth (callable): User-defined function to compute
-                the agent's wealth of the next period (t + 1). The inputs
-                ```saving```, ```shock```, ```params``` and ```options```
-                are already partialled in.
-            - transition_vector_by_state (Callable): Partialled transition function
-                return transition vector for each state.
-            - final_period_partial (Callable): Partialled function for calculating the
-                consumption as well as value function and marginal utility in the final
-                period.
-        compute_upper_envelope (Callable): Function for calculating the upper
-            envelope of the policy and value function. If the number of discrete
-            choices is 1, this function is a dummy function that returns the policy
-            and value function as is, without performing a fast upper envelope
-            scan.
+        model_config (dict): Dictionary containing the model configuration.
+        model_funcs (dict): Dictionary containing model functions.
+        model_structure (dict): Dictionary containing model structure.
+        batch_info (dict): Dictionary containing batch information.

     Returns:
-        dict: Dictionary containing the period-specific endog_grid, policy, and value
+        Tuple: Tuple containing the period-specific endog_grid, policy, and value
             from the backward induction.

     """
     continuous_states_info = model_config["continuous_states_info"]

-    cont_grids_next_period = calc_cont_grids_next_period(
-        model_structure=model_structure,
-        model_config=model_config,
-        income_shock_draws_unscaled=income_shock_draws_unscaled,
-        params=params,
-        model_funcs=model_funcs,
+    # Jit the law of motion, capturing the static model objects via the lambda
+    calc_grids_jit = jax.jit(
+        lambda income_shock_draws, params_inner: calc_cont_grids_next_period(
+            model_structure=model_structure,
+            model_config=model_config,
+            income_shock_draws_unscaled=income_shock_draws,
+            params=params_inner,
+            model_funcs=model_funcs,
+        )
     )

-    # Create solution containers. 
The 20 percent extra in wealth grid needs to go - # into tuning parameters - n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"] + cont_grids_next_period = calc_grids_jit(income_shock_draws_unscaled, params) + ( value_solved, policy_solved, endog_grid_solved, ) = create_solution_container( - model_config=model_config, - model_structure=model_structure, + continuous_states_info=model_config["continuous_states_info"], + # Read out grid size + n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"], + n_state_choices=model_structure["state_choice_space"].shape[0], + ) + + # Solve the last two periods using lambda to capture static arguments + solve_last_two_period_jit = jax.jit( + lambda params_inner, cont_grids, weights, val_solved, pol_solved, endog_solved: solve_last_two_periods( + params=params_inner, + continuous_states_info=continuous_states_info, + cont_grids_next_period=cont_grids, + income_shock_weights=weights, + model_funcs=model_funcs, + last_two_period_batch_info=batch_info["last_two_period_info"], + value_solved=val_solved, + policy_solved=pol_solved, + endog_grid_solved=endog_solved, + debug_info=None, + ) ) - # Solve the last two periods. We do this separately as the marginal utility of - # the child states in the last period is calculated from the marginal utility - # function of the bequest function, which might differ. ( value_solved, policy_solved, endog_grid_solved, - ) = solve_last_two_periods( - params=params, - continuous_states_info=continuous_states_info, - cont_grids_next_period=cont_grids_next_period, - income_shock_weights=income_shock_weights, - model_funcs=model_funcs, - last_two_period_batch_info=batch_info["last_two_period_info"], - value_solved=value_solved, - policy_solved=policy_solved, - endog_grid_solved=endog_grid_solved, + ) = solve_last_two_period_jit( + params, + cont_grids_next_period, + income_shock_weights, + value_solved, + policy_solved, + endog_grid_solved, ) # If it is a two period model we are done. 
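+    # A note on the jit pattern used above: `jax.jit` is applied to a lambda so
+    # that the static Python containers (model_funcs, model_config, batch_info)
+    # are captured by closure and baked into the trace as constants, while only
+    # the array arguments are traced. A minimal sketch of the same pattern, with
+    # `scale_rows` and `static_cfg` as hypothetical stand-ins:
+    #
+    #     import jax
+    #     import jax.numpy as jnp
+    #
+    #     static_cfg = {"scale": 2.0}  # captured by closure, not traced
+    #
+    #     def scale_rows(x, cfg):
+    #         return x * cfg["scale"]
+    #
+    #     scale_jit = jax.jit(lambda x: scale_rows(x, static_cfg))
+    #     out = scale_jit(jnp.ones((3, 2)))  # only `x` is a traced argument
+    #
+    # As noted above, for a two-period model we can return immediately: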
if batch_info["two_period_model"]:
         return value_solved, policy_solved, endog_grid_solved

-    def partial_single_period(carry, xs):
-        return solve_single_period(
-            carry=carry,
-            xs=xs,
-            params=params,
-            continuous_grids_info=continuous_states_info,
-            cont_grids_next_period=cont_grids_next_period,
-            model_funcs=model_funcs,
-            income_shock_weights=income_shock_weights,
-        )
+    # Create the single-period solver as a lambda that captures the static arguments
+    partial_single_period = lambda carry, xs: solve_single_period(
+        carry=carry,
+        xs=xs,
+        params=params,
+        continuous_grids_info=continuous_states_info,
+        cont_grids_next_period=cont_grids_next_period,
+        model_funcs=model_funcs,
+        income_shock_weights=income_shock_weights,
+        debug_info=None,
+    )

     for id_segment in range(batch_info["n_segments"]):
         segment_info = batch_info[f"batches_info_segment_{id_segment}"]
@@ -192,53 +165,3 @@ def partial_single_period(carry, xs):
         policy_solved,
         endog_grid_solved,
     )
-
-
-def create_solution_container(
-    model_config: Dict[str, Any],
-    model_structure: Dict[str, Any],
-):
-    """Create solution containers for value, policy, and endog_grid."""
-
-    # Read out grid size
-    n_total_wealth_grid = model_config["tuning_params"]["n_total_wealth_grid"]
-    n_state_choices = model_structure["state_choice_space"].shape[0]
-
-    # Check if second continuous state exists and read out array size
-    continuous_states_info = model_config["continuous_states_info"]
-    if continuous_states_info["second_continuous_exists"]:
-        n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"]
-
-        value_solved = jnp.full(
-            (n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-        policy_solved = jnp.full(
-            (n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-        endog_grid_solved = jnp.full(
-            (n_state_choices, n_second_continuous_grid, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-    else:
-        value_solved = jnp.full(
-            (n_state_choices, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-        policy_solved = jnp.full(
-            (n_state_choices, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-        endog_grid_solved = jnp.full(
-            (n_state_choices, n_total_wealth_grid),
-            dtype=jnp.float64,
-            fill_value=jnp.nan,
-        )
-
-    return value_solved, policy_solved, endog_grid_solved
diff --git a/src/dcegm/egm/aggregate_marginal_utility.py b/src/dcegm/egm/aggregate_marginal_utility.py
index 08ec2af5..a707ecda 100644
--- a/src/dcegm/egm/aggregate_marginal_utility.py
+++ b/src/dcegm/egm/aggregate_marginal_utility.py
@@ -1,13 +1,12 @@
 from typing import Tuple

 import jax.numpy as jnp
-import numpy as np


 def aggregate_marg_utils_and_exp_values(
     value_state_choice_specific: jnp.ndarray,
     marg_util_state_choice_specific: jnp.ndarray,
-    reshape_state_choice_vec_to_mat: np.ndarray,
+    reshape_state_choice_vec_to_mat: jnp.ndarray,
     taste_shock_scale,
     taste_shock_scale_is_scalar,
     income_shock_weights: jnp.ndarray,
@@ -47,11 +46,12 @@ def aggregate_marg_utils_and_exp_values(
         mode="fill",
         fill_value=jnp.nan,
     )
+
     # If the taste shock scale is not a scalar, we select from the array,
     # where we have one taste shock scale per choice. 
They are by construction # the same for all choices in a state if not taste_shock_scale_is_scalar: - one_choice_per_state = np.min(reshape_state_choice_vec_to_mat, axis=1) + one_choice_per_state = jnp.min(reshape_state_choice_vec_to_mat, axis=1) taste_shock_scale = jnp.take( taste_shock_scale, one_choice_per_state, diff --git a/src/dcegm/egm/interpolate_marginal_utility.py b/src/dcegm/egm/interpolate_marginal_utility.py index 0d1e19e1..96217cdf 100644 --- a/src/dcegm/egm/interpolate_marginal_utility.py +++ b/src/dcegm/egm/interpolate_marginal_utility.py @@ -164,9 +164,9 @@ def interp1d_value_and_marg_util_for_state_choice( def interp_on_single_wealth_point(wealth_point): policy_interp, value_interp = interp1d_policy_and_value_on_wealth( wealth=wealth_point, - endog_grid=endog_grid_child_state_choice, - policy=policy_child_state_choice, - value=value_child_state_choice, + wealth_grid=endog_grid_child_state_choice, + policy_grid=policy_child_state_choice, + value_grid=value_child_state_choice, compute_utility=compute_utility, state_choice_vec=state_choice_vec, params=params, diff --git a/src/dcegm/egm/solve_euler_equation.py b/src/dcegm/egm/solve_euler_equation.py index b1a10c06..3cc3f8da 100644 --- a/src/dcegm/egm/solve_euler_equation.py +++ b/src/dcegm/egm/solve_euler_equation.py @@ -2,21 +2,20 @@ from typing import Callable, Dict, Tuple -import numpy as np from jax import numpy as jnp from jax import vmap def calculate_candidate_solutions_from_euler_equation( - continuous_grids_info: np.ndarray, + continuous_grids_info: jnp.ndarray, marg_util_next: jnp.ndarray, emax_next: jnp.ndarray, - state_choice_mat: np.ndarray, - idx_post_decision_child_states: np.ndarray, + state_choice_mat: jnp.ndarray, + idx_post_decision_child_states: jnp.ndarray, model_funcs: Dict[str, Callable], has_second_continuous_state: bool, params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Calculate candidates for the optimal policy and value function.""" feasible_marg_utils_child = jnp.take( @@ -78,14 +77,14 @@ def calculate_candidate_solutions_from_euler_equation( def compute_optimal_policy_and_value_wrapper( - marg_util_next: np.ndarray, - emax_next: np.ndarray, - second_continuous_grid: np.ndarray, - assets_grid_end_of_period: np.ndarray, + marg_util_next: jnp.ndarray, + emax_next: jnp.ndarray, + second_continuous_grid: jnp.ndarray, + assets_grid_end_of_period: jnp.ndarray, state_choice_vec: Dict, model_funcs: Dict[str, Callable], params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Write second continuous grid point into state_choice_vec.""" state_choice_vec["continuous_state"] = second_continuous_grid @@ -100,13 +99,13 @@ def compute_optimal_policy_and_value_wrapper( def compute_optimal_policy_and_value( - marg_util_next: np.ndarray, - emax_next: np.ndarray, - assets_grid_end_of_period: np.ndarray, + marg_util_next: jnp.ndarray, + emax_next: jnp.ndarray, + assets_grid_end_of_period: jnp.ndarray, state_choice_vec: Dict, model_funcs: Dict[str, Callable], params: Dict[str, float], -) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Compute optimal child-state- and choice-specific policy and value function. 
Given the marginal utilities of possible child states and next period wealth, we

@@ -173,14 +172,14 @@ def compute_optimal_policy_and_value(

 def solve_euler_equation(
     state_choice_vec: dict,
-    marg_util_next: np.ndarray,
-    emax_next: np.ndarray,
+    marg_util_next: jnp.ndarray,
+    emax_next: jnp.ndarray,
     compute_inverse_marginal_utility: Callable,
     compute_stochastic_transition_vec: Callable,
     params: Dict[str, float],
     discount_factor: float,
     interest_rate: float,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
     """Solve the Euler equation for given discrete choice and child states.

     We integrate over the exogenous process and income uncertainty and
diff --git a/src/dcegm/final_periods.py b/src/dcegm/final_periods.py
index e5a0a4e7..a42787d5 100644
--- a/src/dcegm/final_periods.py
+++ b/src/dcegm/final_periods.py
@@ -21,6 +21,7 @@ def solve_last_two_periods(
     value_solved,
     policy_solved,
     endog_grid_solved,
+    debug_info,
 ):
     """Solves the last two periods of the model.

@@ -47,7 +48,6 @@ def solve_last_two_periods(
         for all states, end of period assets, and income shocks.

     """
-
     (
         value_solved,
         policy_solved,
@@ -86,7 +86,7 @@ def solve_last_two_periods(
         last_two_period_batch_info["state_choice_mat_final_period"], params
     )

-    endog_grid, policy, value = solve_for_interpolated_values(
+    out_dict_second_last = solve_for_interpolated_values(
         value_interpolated=value_interp_final_period,
         marginal_utility_interpolated=marginal_utility_final_last_period,
         state_choice_mat=last_two_period_batch_info[
@@ -104,20 +104,48 @@ def solve_last_two_periods(
         income_shock_weights=income_shock_weights,
         continuous_grids_info=continuous_states_info,
         model_funcs=model_funcs,
+        debug_info=debug_info,
     )
     idx_second_last = last_two_period_batch_info["idx_state_choices_second_last_period"]

-    value_solved = value_solved.at[idx_second_last, ...].set(value)
-    policy_solved = policy_solved.at[idx_second_last, ...].set(policy)
-    endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(endog_grid)
-
-    return (
-        value_solved,
-        policy_solved,
-        endog_grid_solved,
+    value_solved = value_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["value"]
+    )
+    policy_solved = policy_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["policy"]
+    )
+    endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["endog_grid"]
     )

+    # If the function is not called in debug mode, assign everything and return.
+    if debug_info is None:
+        return (
+            value_solved,
+            policy_solved,
+            endog_grid_solved,
+        )
+
+    else:
+        # If the candidates are needed as well, return them in addition to the solution containers.
+        if debug_info["return_candidates"]:
+            return (
+                value_solved,
+                policy_solved,
+                endog_grid_solved,
+                out_dict_second_last["value_candidates"],
+                out_dict_second_last["policy_candidates"],
+                out_dict_second_last["endog_grid_candidates"],
+            )
+
+        else:
+            return (
+                value_solved,
+                policy_solved,
+                endog_grid_solved,
+            )
+

 def solve_final_period(
     idx_state_choices_final_period,
diff --git a/src/dcegm/interfaces/inspect_structure.py b/src/dcegm/interfaces/index_functions.py
similarity index 56%
rename from src/dcegm/interfaces/inspect_structure.py
rename to src/dcegm/interfaces/index_functions.py
index 2944e119..f4a84e7a 100644
--- a/src/dcegm/interfaces/inspect_structure.py
+++ b/src/dcegm/interfaces/index_functions.py
@@ -1,7 +1,9 @@
-def get_child_state_index_per_state_choice(states, choice, model_structure):
-    states_choice_dict = {**states, "choice": choice}
-    state_choice_index = get_state_choice_index_per_discrete_state_and_choice(
-        model_structure, states_choice_dict
+import numpy as np
+
+
+def get_child_state_index_per_states_and_choices(states, choices, model_structure):
+    state_choice_index = get_state_choice_index_per_discrete_states_and_choices(
+        model_structure, states, choices
     )

     child_states = model_structure["map_state_choice_to_child_states"][
@@ -11,7 +13,7 @@ def get_child_state_index_per_state_choice(states, choice, model_structure):
     return child_states


-def get_state_choice_index_per_discrete_state(
+def get_state_choice_index_per_discrete_states(
     states, map_state_choice_to_index, discrete_states_names
 ):
     """Get the state-choice index for a given set of discrete states.
@@ -29,30 +31,40 @@ def get_state_choice_index_per_discrete_state(
     indexes = map_state_choice_to_index[
         tuple((states[key],) for key in discrete_states_names)
     ]
+    # Needs a flag so this is only evaluated outside of jit mode
+    # max_values_per_state = {key: np.max(states[key]) for key in discrete_states_names}
+    # # Check that max value does not exceed the dimension
+    # dim = map_state_choice_to_index.shape
+    # for i, key in enumerate(discrete_states_names):
+    #     if max_values_per_state[key] > dim[i] - 1:
+    #         raise ValueError(
+    #             f"Max value of state {key} exceeds the dimension of the model."
+    #         )

     # As the code above generates a dummy dimension in the first index, remove it
     return indexes[0]


-def get_state_choice_index_per_discrete_state_and_choice(
-    model_structure, state_choice_dict
+def get_state_choice_index_per_discrete_states_and_choices(
+    model_structure, states, choices
 ):
     """Get the state-choice index for a given set of discrete states and a choice.

     Args:
-        model (dict): A dictionary representing the model. Must contain
-            'model_structure' with a 'map_state_choice_to_index_with_proxy'
-            and 'discrete_states_names'.
-        state_choice_dict (dict): Dictionary containing discrete states and
+        model_structure (dict): Dictionary containing all information on the model structure.
+        states (dict): Dictionary containing the discrete states; ``choices`` holds
             the choice.

     Returns:
        int: The index corresponding to the specified discrete states and choice.
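+
+    Example (a sketch with hypothetical state names; assumes a model whose
+    discrete states are ``period`` and ``lagged_choice``):
+
+        idx = get_state_choice_index_per_discrete_states_and_choices(
+            model_structure=model_structure,
+            states={"period": np.array([0]), "lagged_choice": np.array([1])},
+            choices=np.array([0]),
+        )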
""" + state_choices = {"choice": choices, **states} + map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] discrete_states_names = model_structure["discrete_states_names"] state_choice_tuple = tuple( - state_choice_dict[st] for st in discrete_states_names + ["choice"] + state_choices[st] for st in discrete_states_names + ["choice"] ) state_choice_index = map_state_choice_to_index[state_choice_tuple] diff --git a/src/dcegm/interfaces/inspect_solution.py b/src/dcegm/interfaces/inspect_solution.py new file mode 100644 index 00000000..d3e4e63c --- /dev/null +++ b/src/dcegm/interfaces/inspect_solution.py @@ -0,0 +1,245 @@ +import copy + +import jax.lax +import jax.numpy as jnp +import numpy as np + +from dcegm.final_periods import solve_last_two_periods +from dcegm.law_of_motion import calc_cont_grids_next_period +from dcegm.pre_processing.sol_container import create_solution_container +from dcegm.solve_single_period import solve_single_period + + +def partially_solve( + income_shock_draws_unscaled, + income_shock_weights, + model_config, + batch_info, + model_funcs, + model_structure, + params, + n_periods, + return_candidates=False, +): + """Partially solve the model for the last n_periods. + + This method allows for large models to only solve part of the model, to debug the solution process. + + Args: + params: Model parameters. + n_periods: Number of periods to solve. + return_candidates: If True, additionally return candidate solutions before applying the upper envelope. + + """ + batch_info_internal = copy.deepcopy(batch_info) + + if n_periods < 2: + raise ValueError("You must at least solve for two periods.") + + continuous_states_info = model_config["continuous_states_info"] + + cont_grids_next_period = calc_cont_grids_next_period( + model_structure=model_structure, + model_config=model_config, + income_shock_draws_unscaled=income_shock_draws_unscaled, + params=params, + model_funcs=model_funcs, + ) + # Determine the last period we need to solve for. 
+ last_relevant_period = model_config["n_periods"] - n_periods + + relevant_state_choices_mask = ( + model_structure["state_choice_space"][:, 0] >= last_relevant_period + ) + relevant_state_choice_space = model_structure["state_choice_space"][ + relevant_state_choices_mask + ] + + ( + value_solved, + policy_solved, + endog_grid_solved, + ) = create_solution_container( + continuous_states_info=model_config["continuous_states_info"], + # Read out grid size + n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"], + n_state_choices=relevant_state_choice_space.shape[0], + ) + + if return_candidates: + n_assets_end_of_period = model_config["continuous_states_info"][ + "assets_grid_end_of_period" + ].shape[0] + (value_candidates, policy_candidates, endog_grid_candidates) = ( + create_solution_container( + continuous_states_info=model_config["continuous_states_info"], + n_total_wealth_grid=n_assets_end_of_period, + n_state_choices=relevant_state_choice_space.shape[0], + ) + ) + + # Determine rescale idx for reduced solution + rescale_idx = np.where(relevant_state_choices_mask)[0].min() + + # Create debug information + debug_info = { + "return_candidates": return_candidates, + } + last_two_period_batch_info = batch_info_internal["last_two_period_info"] + # Rescale the indexes to save of the last two periods: + last_two_period_batch_info["idx_state_choices_final_period"] = ( + last_two_period_batch_info["idx_state_choices_final_period"] - rescale_idx + ) + last_two_period_batch_info["idx_state_choices_second_last_period"] = ( + last_two_period_batch_info["idx_state_choices_second_last_period"] - rescale_idx + ) + ( + value_solved, + policy_solved, + endog_grid_solved, + value_candidates_second_last, + policy_candidates_second_last, + endog_grid_candidates_second_last, + ) = solve_last_two_periods( + params=params, + continuous_states_info=continuous_states_info, + cont_grids_next_period=cont_grids_next_period, + income_shock_weights=income_shock_weights, + model_funcs=model_funcs, + last_two_period_batch_info=last_two_period_batch_info, + value_solved=value_solved, + policy_solved=policy_solved, + endog_grid_solved=endog_grid_solved, + debug_info=debug_info, + ) + if return_candidates: + idx_second_last = batch_info_internal["last_two_period_info"][ + "idx_state_choices_second_last_period" + ] + value_candidates = value_candidates.at[idx_second_last, ...].set( + value_candidates_second_last + ) + policy_candidates = policy_candidates.at[idx_second_last, ...].set( + policy_candidates_second_last, + ) + endog_grid_candidates = endog_grid_candidates.at[idx_second_last, ...].set( + endog_grid_candidates_second_last + ) + + if n_periods <= 2: + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } + if return_candidates: + out_dict["value_candidates"] = value_candidates + out_dict["policy_candidates"] = policy_candidates + out_dict["endog_grid_candidates"] = endog_grid_candidates + + return out_dict + + stop_segment_loop = False + for id_segment in range(batch_info_internal["n_segments"]): + segment_info = batch_info_internal[f"batches_info_segment_{id_segment}"] + + n_batches_in_segment = segment_info["batches_state_choice_idx"].shape[0] + + for id_batch in range(n_batches_in_segment): + periods_batch = segment_info["state_choices"]["period"][id_batch, :] + + # Now there can be three cases: + # 1) All periods are smaller than the last relevant period. 
Then we stop the loop + # 2) Part of the periods are smaller than the last relevant period. Then we only solve for the partial state choices. + # 3) All periods are larger than the last relevant period. Then we solve for state choices. + if (periods_batch < last_relevant_period).all(): + stop_segment_loop = True + break + elif (periods_batch < last_relevant_period).any(): + solve_mask = periods_batch >= last_relevant_period + state_choices_batch = { + key: segment_info["state_choices"][key][id_batch, solve_mask] + for key in segment_info["state_choices"].keys() + } + # We need to rescale the idx, because of saving + idx_to_solve = ( + segment_info["batches_state_choice_idx"][id_batch, solve_mask] + - rescale_idx + ) + child_states_to_integrate_stochastic = segment_info[ + "child_states_to_integrate_stochastic" + ][id_batch, solve_mask, :] + + else: + state_choices_batch = { + key: segment_info["state_choices"][key][id_batch, :] + for key in segment_info["state_choices"].keys() + } + # We need to rescale the idx, because of saving + idx_to_solve = ( + segment_info["batches_state_choice_idx"][id_batch, :] - rescale_idx + ) + child_states_to_integrate_stochastic = segment_info[ + "child_states_to_integrate_stochastic" + ][id_batch, :, :] + + state_choices_childs_batch = { + key: segment_info["state_choices_childs"][key][id_batch, :] + for key in segment_info["state_choices_childs"].keys() + } + + child_state_choice_idxs_to_interp = ( + segment_info["child_state_choice_idxs_to_interp"][id_batch, :] + - rescale_idx + ) + + xs = ( + idx_to_solve, + segment_info["child_state_choices_to_aggr_choice"][id_batch, :, :], + child_states_to_integrate_stochastic, + child_state_choice_idxs_to_interp, + segment_info["child_states_idxs"][id_batch, :], + state_choices_batch, + state_choices_childs_batch, + ) + carry = (value_solved, policy_solved, endog_grid_solved) + single_period_out_dict = solve_single_period( + carry=carry, + xs=xs, + params=params, + continuous_grids_info=continuous_states_info, + cont_grids_next_period=cont_grids_next_period, + model_funcs=model_funcs, + income_shock_weights=income_shock_weights, + debug_info=debug_info, + ) + + value_solved = single_period_out_dict["value"] + policy_solved = single_period_out_dict["policy"] + endog_grid_solved = single_period_out_dict["endog_grid"] + + # If candidates are requested, we assign them to the solution container + if return_candidates: + value_candidates = value_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["value_candidates"] + ) + policy_candidates = policy_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["policy_candidates"] + ) + endog_grid_candidates = endog_grid_candidates.at[idx_to_solve, ...].set( + single_period_out_dict["endog_grid_candidates"] + ) + + if stop_segment_loop: + break + + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } + if return_candidates: + out_dict["value_candidates"] = value_candidates + out_dict["policy_candidates"] = policy_candidates + out_dict["endog_grid_candidates"] = endog_grid_candidates + return out_dict diff --git a/src/dcegm/interfaces/interface.py b/src/dcegm/interfaces/interface.py index a5da8813..965706ff 100644 --- a/src/dcegm/interfaces/interface.py +++ b/src/dcegm/interfaces/interface.py @@ -3,16 +3,16 @@ import jax import jax.numpy as jnp import pandas as pd +from jax import numpy as jnp -from dcegm.interpolation.interp1d import ( - interp1d_policy_and_value_on_wealth, - interp_policy_on_wealth, - 
interp_value_on_wealth, +from dcegm.interfaces.index_functions import ( + get_state_choice_index_per_discrete_states_and_choices, ) -from dcegm.interpolation.interp2d import ( - interp2d_policy_and_value_on_wealth_and_regular_grid, - interp2d_policy_on_wealth_and_regular_grid, - interp2d_value_on_wealth_and_regular_grid, +from dcegm.interfaces.interface_checks import check_states_and_choices +from dcegm.interpolation.interp_interfaces import ( + interpolate_policy_and_value_for_state_and_choice, + interpolate_policy_for_state_and_choice, + interpolate_value_for_state_and_choice, ) @@ -35,9 +35,9 @@ def get_n_state_choice_period(model): ) -def policy_and_value_for_state_choice_vec( +def policy_and_value_for_states_and_choices( states, - choice, + choices, params, endog_grid_solved, value_solved, @@ -67,68 +67,38 @@ def policy_and_value_for_state_choice_vec( choice. """ - # ToDo: Check if states contains relevant structure - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - if continuous_states_info["second_continuous_exists"]: - - second_continuous = state_choice_vec[ - continuous_states_info["second_continuous_state_name"] - ] - - policy, value = interp2d_policy_and_value_on_wealth_and_regular_grid( - regular_grid=continuous_states_info["second_continuous_grid"], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value_grid=jnp.take(value_solved, state_choice_index, axis=0), - policy_grid=jnp.take(policy_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - else: - policy, value = interp1d_policy_and_value_on_wealth( - wealth=state_choice_vec["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy=jnp.take(policy_solved, state_choice_index, axis=0), - value=jnp.take(value_solved, state_choice_index, axis=0), - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - return policy, value + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=state_choices, choices=choices, model_structure=model_structure + ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) + policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) + + policy, value = jax.vmap( + interpolate_policy_and_value_for_state_and_choice, + in_axes=(0, 0, 0, 0, None, None, None), + )( + value_grid_state_choice, + policy_grid_state_choice, + 
endog_grid_state_choice, + state_choices, + params, + model_config, + model_funcs, + ) + return ( + jnp.squeeze(policy), + jnp.squeeze(value), + ) -def value_for_state_choice_vec( +def value_for_state_and_choice( states, - choice, + choices, params, endog_grid_solved, value_solved, @@ -153,65 +123,33 @@ def value_for_state_choice_vec( float: The value at the given state and choice. """ - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - compute_utility = model_funcs["compute_utility"] - - if continuous_states_info["second_continuous_exists"]: - second_continuous = state_choice_vec[ - continuous_states_info["second_continuous_state_name"] - ] - - value = interp2d_value_on_wealth_and_regular_grid( - regular_grid=continuous_states_info["second_continuous_grid"], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value_grid=jnp.take(value_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=state_choice_vec["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - else: - - value = interp_value_on_wealth( - wealth=state_choice_vec["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - value=jnp.take(value_solved, state_choice_index, axis=0), - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - return value + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=state_choices, choices=choices, model_structure=model_structure + ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + value_grid_state_choice = jnp.take(value_solved, state_choice_idx, axis=0) + + value = jax.vmap( + interpolate_value_for_state_and_choice, + in_axes=(0, 0, 0, None, None, None), + )( + value_grid_state_choice, + endog_grid_state_choice, + state_choices, + params, + model_config, + model_funcs, + ) + return jnp.squeeze(value) def policy_for_state_choice_vec( states, - choice, + choices, endog_grid_solved, policy_solved, model_structure, @@ -233,51 +171,26 @@ def policy_for_state_choice_vec( float: The policy at the given state and choice. 
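+
+    Example (a sketch with hypothetical state names; assumes a solved model
+    with discrete states ``period`` and ``lagged_choice``):
+
+        policy = policy_for_state_choice_vec(
+            states={
+                "period": np.array([0, 0]),
+                "lagged_choice": np.array([0, 1]),
+                "assets_begin_of_period": np.array([10.0, 12.5]),
+            },
+            choices=np.array([0, 1]),
+            endog_grid_solved=endog_grid,
+            policy_solved=policy_solved,
+            model_structure=model_structure,
+            model_config=model_config,
+        )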
""" - map_state_choice_to_index = model_structure["map_state_choice_to_index_with_proxy"] - discrete_states_names = model_structure["discrete_states_names"] - - if "dummy_stochastic" in discrete_states_names: - state_choice_vec = { - **states, - "choice": choice, - "dummy_stochastic": 0, - } - - else: - state_choice_vec = { - **states, - "choice": choice, - } - - state_choice_tuple = tuple( - state_choice_vec[st] for st in discrete_states_names + ["choice"] + state_choices = check_states_and_choices( + states=states, choices=choices, model_structure=model_structure ) - state_choice_index = map_state_choice_to_index[state_choice_tuple] - continuous_states_info = model_config["continuous_states_info"] - - if continuous_states_info["second_continuous_exists"]: - second_continuous = states[ - continuous_states_info["second_continuous_state_name"] - ] - - policy = interp2d_policy_on_wealth_and_regular_grid( - regular_grid=model_config["continuous_states_info"][ - "second_continuous_grid" - ], - wealth_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy_grid=jnp.take(policy_solved, state_choice_index, axis=0), - regular_point_to_interp=second_continuous, - wealth_point_to_interp=states["assets_begin_of_period"], - ) - - else: - policy = interp_policy_on_wealth( - wealth=states["assets_begin_of_period"], - endog_grid=jnp.take(endog_grid_solved, state_choice_index, axis=0), - policy=jnp.take(policy_solved, state_choice_index, axis=0), - ) - return policy + state_choice_idx = get_state_choice_index_per_discrete_states_and_choices( + states=state_choices, choices=choices, model_structure=model_structure + ) + endog_grid_state_choice = jnp.take(endog_grid_solved, state_choice_idx, axis=0) + policy_grid_state_choice = jnp.take(policy_solved, state_choice_idx, axis=0) + + policy = jax.vmap( + interpolate_policy_for_state_and_choice, + in_axes=(0, 0, 0, None), + )( + policy_grid_state_choice, + endog_grid_state_choice, + state_choices, + model_config, + ) + return jnp.squeeze(policy) def validate_stochastic_transition(params, model_config, model_funcs, model_structure): @@ -377,3 +290,117 @@ def stochastic_transition_vec(state_choice_vec_dict, func, params): """ return func(**state_choice_vec_dict, params=params) + + +def choice_values_for_states( + value_solved, + endog_grid_solved, + state_choice_indexes, + params, + states, + model_config, + model_funcs, +): + value_grid_states = jnp.take( + value_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + endog_grid_states = jnp.take( + endog_grid_solved, + state_choice_indexes, + axis=0, + mode="fill", + fill_value=jnp.nan, + ) + + def wrapper_interp_value_for_choice( + state, + value_grid_state_choice, + endog_grid_state_choice, + choice, + ): + state_choice_vec = {**state, "choice": choice} + + return interpolate_value_for_state_and_choice( + value_grid_state_choice=value_grid_state_choice, + endog_grid_state_choice=endog_grid_state_choice, + state_choice_vec=state_choice_vec, + params=params, + model_config=model_config, + model_funcs=model_funcs, + ) + + # Read out choice range to loop over + choice_range = model_config["choices"] + + choice_values_per_state = jax.vmap( + jax.vmap( + wrapper_interp_value_for_choice, + in_axes=(None, 0, 0, 0), + ), + in_axes=(0, 0, 0, None), + )( + states, + value_grid_states, + endog_grid_states, + choice_range, + ) + return choice_values_per_state + + +def choice_policies_for_states( + policy_solved, + endog_grid_solved, + state_choice_indexes, + states, + 
model_config,
+):
+    policy_grid_states = jnp.take(
+        policy_solved,
+        state_choice_indexes,
+        axis=0,
+        mode="fill",
+        fill_value=jnp.nan,
+    )
+    endog_grid_states = jnp.take(
+        endog_grid_solved,
+        state_choice_indexes,
+        axis=0,
+        mode="fill",
+        fill_value=jnp.nan,
+    )
+
+    def wrapper_interp_policy_for_choice(
+        state,
+        policy_grid_state_choice,
+        endog_grid_state_choice,
+        choice,
+    ):
+        state_choice_vec = {**state, "choice": choice}
+
+        return interpolate_policy_for_state_and_choice(
+            policy_grid_state_choice=policy_grid_state_choice,
+            endog_grid_state_choice=endog_grid_state_choice,
+            state_choice_vec=state_choice_vec,
+            model_config=model_config,
+        )
+
+    # Read out choice range to loop over
+    choice_range = model_config["choices"]
+
+    choice_policies_per_state = jax.vmap(
+        jax.vmap(
+            wrapper_interp_policy_for_choice,
+            in_axes=(None, 0, 0, 0),
+        ),
+        in_axes=(0, 0, 0, None),
+    )(
+        states,
+        policy_grid_states,
+        endog_grid_states,
+        choice_range,
+    )
+    return choice_policies_per_state
diff --git a/src/dcegm/interfaces/interface_checks.py b/src/dcegm/interfaces/interface_checks.py
new file mode 100644
index 00000000..61412f17
--- /dev/null
+++ b/src/dcegm/interfaces/interface_checks.py
@@ -0,0 +1,123 @@
+import numpy as np
+from jax import numpy as jnp
+
+
+def check_states_and_choices(states, choices, model_structure):
+    """Check if the states and choices are valid according to the model structure.
+
+    Args:
+        states (dict): Dictionary containing state values.
+        choices (int or array): The choice value(s).
+        model_structure (dict): Model structure containing information on
+            discrete states and choices.
+
+    Returns:
+        state_choices (dict): Dictionary containing the validated states and
+            choices, with the choices stored under the key ``choice``.
+
+    """
+    discrete_states_names = model_structure["discrete_states_names"]
+    if "dummy_stochastic" in discrete_states_names:
+        if "dummy_stochastic" not in states.keys():
+            need_to_add_dummy = True
+            # Check that all discrete states are present in states, except for the dummy stochastic state
+            observed_discrete_states = list(
+                set(discrete_states_names) - {"dummy_stochastic"}
+            )
+
+        else:
+            need_to_add_dummy = False
+            observed_discrete_states = discrete_states_names.copy()
+
+    else:
+        need_to_add_dummy = False
+        observed_discrete_states = discrete_states_names.copy()
+
+    if not all(state in states.keys() for state in observed_discrete_states):
+        raise ValueError("States should contain all discrete states.")
+
+    # Now check the dimensions. First, check whether the states are arrays or
+    # integers. If choices is an integer, all states need to be integers as
+    # well, and we convert them to arrays.
+    if isinstance(choices, float):
+        raise ValueError("Choices should be integers or arrays, not floats.")
+    # Check if choices is a single Python or numpy integer
+    elif isinstance(choices, (int, np.integer)):
+        choices = np.array([choices])
+        single_state = True
+        # Now check if all states are integers
+        if not all(
+            isinstance(states[key], (int, np.integer))
+            for key in observed_discrete_states
+        ):
+            raise ValueError(
+                "Discrete states should be integers or arrays. "
+                "As choices is a single integer, all states must be integers as well."
+            )
+        else:
+            states = {key: np.array([value]) for key, value in states.items()}
+
+    elif isinstance(choices, (np.ndarray, jnp.ndarray)):
+        if choices.ndim == 0:
+            # Raise if choices does not have an integer dtype
+            if not np.issubdtype(choices.dtype, np.integer):
+                raise ValueError(
+                    "Choices should be integers or arrays with integer dtype."
+ ) + + choices = np.array([choices]) + single_state = True + # Now check if all states have dimension 0 as well + if not all(states[key].ndim == 0 for key in states.keys()): + raise ValueError( + "All states and choices must have the same dimension. Choices is dimension 0." + ) + # All observed discrete states must have dtype int as well + if not all( + states[key].dtype in [int, np.integer] + for key in observed_discrete_states + ): + raise ValueError( + "Discrete states should be integers or arrays with integer dtype. " + ) + states = {key: np.array([value]) for key, value in states.items()} + elif choices.ndim == 1: + # Check if choices has dtype int + if not np.issubdtype(choices.dtype, np.integer): + raise ValueError( + "Choices should be integers or arrays with integer dtype." + ) + single_state = False + # Check if all states are arrays with the same dimension as choices + if not all( + states[key].ndim == 1 and states[key].shape[0] == choices.shape[0] + for key in states.keys() + ): + raise ValueError( + "All states and choices must have the same dimension. Choices is dimension 1." + ) + # All observed discrete states must have dtype int as well + if not all( + np.issubdtype(states[key].dtype, np.integer) + for key in observed_discrete_states + ): + raise ValueError( + "Discrete states should be integers or arrays with integer dtype. " + ) + else: + raise ValueError( + "Choices should be integers or arrays with dimension 0 or 1." + ) + else: + raise ValueError("Choices should be integers or arrays with dimension 0 or 1.") + + if need_to_add_dummy: + if single_state: + states["dummy_stochastic"] = np.array([0]) + else: + states["dummy_stochastic"] = np.zeros(choices.shape[0], dtype=int) + + state_choices = { + **states, + "choice": choices, + } + return state_choices diff --git a/src/dcegm/interfaces/model_class.py b/src/dcegm/interfaces/model_class.py index 14450158..187f6f4f 100644 --- a/src/dcegm/interfaces/model_class.py +++ b/src/dcegm/interfaces/model_class.py @@ -6,9 +6,9 @@ import pandas as pd from dcegm.backward_induction import backward_induction -from dcegm.interfaces.inspect_structure import ( - get_child_state_index_per_state_choice, - get_state_choice_index_per_discrete_state, +from dcegm.interfaces.index_functions import ( + get_child_state_index_per_states_and_choices, + get_state_choice_index_per_discrete_states, ) from dcegm.interfaces.interface import validate_stochastic_transition from dcegm.interfaces.sol_interface import model_solved @@ -24,6 +24,7 @@ create_model_dict_and_save, load_model_dict, ) +from dcegm.pre_processing.shared import try_jax_array from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -46,7 +47,6 @@ def __init__( ): """Setup the model and check if load or save is required.""" - self.model_specs = model_specs if model_load_path is not None: model_dict = load_model_dict( model_config=model_config, @@ -85,6 +85,9 @@ def __init__( debug_info=debug_info, ) + self.model_specs = jax.tree_util.tree_map(try_jax_array, model_specs) + self.specs_without_jax = model_specs + self.model_config = model_dict["model_config"] self.model_funcs = model_dict["model_funcs"] self.model_structure = model_dict["model_structure"] @@ -99,6 +102,46 @@ def __init__( self.income_shock_draws_unscaled = income_shock_draws_unscaled self.income_shock_weights = income_shock_weights + if alternative_sim_specifications is not None: + self.alternative_sim_funcs = generate_alternative_sim_functions( + 
model_specs=self.specs_without_jax, + model_specs_jax=self.model_specs, + **alternative_sim_specifications, + ) + else: + self.alternative_sim_funcs = None + + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + alternative_specs_without_jax = self.specs_without_jax + else: + self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.alternative_sim_funcs = alternative_sim_funcs + + def backward_induction_inner_jit(self, params): + return backward_induction( + params=params, + income_shock_draws_unscaled=self.income_shock_draws_unscaled, + income_shock_weights=self.income_shock_weights, + model_config=self.model_config, + batch_info=self.batch_info, + model_funcs=self.model_funcs, + model_structure=self.model_structure, + ) + + def get_fast_solve_func(self): backward_jit = jax.jit( partial( backward_induction, @@ -111,14 +154,7 @@ def __init__( ) ) - self.backward_induction_jit = backward_jit - - if alternative_sim_specifications is not None: - self.alternative_sim_funcs = generate_alternative_sim_functions( - model_specs=model_specs, **alternative_sim_specifications - ) - else: - self.alternative_sim_funcs = None + return backward_jit def solve(self, params, load_sol_path=None, save_sol_path=None): """Solve a discrete-continuous life-cycle model using the DC-EGM algorithm. @@ -149,8 +185,9 @@ def solve(self, params, load_sol_path=None, save_sol_path=None): if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) sol_dict = { "value": value, "policy": policy, @@ -195,8 +232,10 @@ def solve_and_simulate( if load_sol_path is not None: sol_dict = pkl.load(open(load_sol_path, "rb")) else: - # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) + sol_dict = { "value": value, "policy": policy, @@ -226,6 +265,7 @@ def get_solve_and_simulate_func( self, states_initial, seed, + slow_version=False, ): sim_func = lambda params, value, policy, endog_gid: simulate_all_periods( @@ -245,7 +285,9 @@ def get_solve_and_simulate_func( def solve_and_simulate_function_to_jit(params): params_processed = process_params(params, self.params_check_info) # Solve the model - value, policy, endog_grid = self.backward_induction_jit(params_processed) + value, policy, endog_grid = self.backward_induction_inner_jit( + params_processed + ) sim_dict = sim_func( params=params_processed, @@ -256,10 +298,13 @@ def solve_and_simulate_function_to_jit(params): return sim_dict - jit_solve_simulate = jax.jit(solve_and_simulate_function_to_jit) + if slow_version: + solve_simulate_func = solve_and_simulate_function_to_jit + else: + solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit) def solve_and_simulate_function(params): - sim_dict = jit_solve_simulate(params) + sim_dict = solve_simulate_func(params) df = create_simulation_df(sim_dict) return df @@ -273,6 +318,7 @@ def 
create_experimental_ll_func( unobserved_state_specs=None, return_model_solution=False, use_probability_of_observed_states=True, + slow_version=False, ): return create_individual_likelihood_function( @@ -280,13 +326,14 @@ def create_experimental_ll_func( model_config=self.model_config, model_funcs=self.model_funcs, model_specs=self.model_specs, - backwards_induction=self.backward_induction_jit, + backwards_induction_inner_jit=self.backward_induction_inner_jit, observed_states=observed_states, observed_choices=observed_choices, params_all=params_all, unobserved_state_specs=unobserved_state_specs, return_model_solution=return_model_solution, use_probability_of_observed_states=use_probability_of_observed_states, + slow_version=slow_version, ) def validate_exogenous(self, params): @@ -300,7 +347,7 @@ def validate_exogenous(self, params): def get_state_choices_idx(self, states): """Get the indices of the state choices for given states.""" - return get_state_choice_index_per_discrete_state( + return get_state_choice_index_per_discrete_states( states=states, map_state_choice_to_index=self.model_structure["map_state_choice_to_index"], discrete_states_names=self.model_structure["discrete_states_names"], @@ -312,8 +359,8 @@ def get_child_states(self, state, choice): "For this function the model needs to be created with debug_info='all'" ) - child_idx = get_child_state_index_per_state_choice( - states=state, choice=choice, model_structure=self.model_structure + child_idx = get_child_state_index_per_states_and_choices( + states=state, choices=choice, model_structure=self.model_structure ) state_space_dict = self.model_structure["state_space_dict"] discrete_states_names = self.model_structure["discrete_states_names"] @@ -322,6 +369,74 @@ def get_child_states(self, state, choice): } return pd.DataFrame(child_states) + def get_child_states_and_calc_trans_probs(self, state, choice, params): + """Get the child states for a given state and choice and calculate the + transition probabilities.""" + child_states_df = self.get_child_states(state, choice) + + trans_probs = self.model_funcs["compute_stochastic_transition_vec"]( + params=params, choice=choice, **state + ) + child_states_df["trans_probs"] = trans_probs + return child_states_df + + def get_full_child_states_by_asset_id_and_probs( + self, state, choice, params, asset_id, second_continuous_id=None + ): + """Get the child states for a given state and choice and calculate the + transition probabilities.""" + if "map_state_choice_to_child_states" not in self.model_structure: + raise ValueError( + "For this function the model needs to be created with debug_info='all'" + ) + + child_idx = get_child_state_index_per_states_and_choices( + states=state, choices=choice, model_structure=self.model_structure + ) + state_space_dict = self.model_structure["state_space_dict"] + discrete_states_names = self.model_structure["discrete_states_names"] + child_states = { + key: state_space_dict[key][child_idx] for key in discrete_states_names + } + child_states_df = pd.DataFrame(child_states) + + child_continuous_states = self.compute_law_of_motions(params=params) + + if "second_continuous" in child_continuous_states.keys(): + if second_continuous_id is None: + raise ValueError("second_continuous_id must be provided.") + else: + quad_wealth = child_continuous_states["assets_begin_of_period"][ + child_idx, second_continuous_id, asset_id, : + ] + next_period_second_continuous = child_continuous_states[ + "second_continuous" + ][child_idx, second_continuous_id] + + 
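+                # Shape note (inferred from the indexing above, not verified):
+                # with a second continuous state, assets_begin_of_period is
+                # indexed as [state_choice, second_continuous, asset, quad_point],
+                # so `quad_wealth` has one row per child state and one column
+                # per income-shock quadrature point.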
second_continuous_name = self.model_config["continuous_states_info"][ + "second_continuous_state_name" + ] + child_states_df[second_continuous_name] = next_period_second_continuous + + else: + if second_continuous_id is not None: + raise ValueError("second_continuous_id must not be provided.") + else: + quad_wealth = child_continuous_states["assets_begin_of_period"][ + child_idx, asset_id, : + ] + + for id_quad in range(quad_wealth.shape[1]): + child_states_df[f"assets_begin_of_period_quad_point_{id_quad}"] = ( + quad_wealth[:, id_quad] + ) + + trans_probs = self.model_funcs["compute_stochastic_transition_vec"]( + params=params, choice=choice, **state + ) + child_states_df["trans_probs"] = trans_probs + return child_states_df + def compute_law_of_motions(self, params): return calc_cont_grids_next_period( params=params, diff --git a/src/dcegm/interfaces/sol_interface.py b/src/dcegm/interfaces/sol_interface.py index f52b222a..d0f3270f 100644 --- a/src/dcegm/interfaces/sol_interface.py +++ b/src/dcegm/interfaces/sol_interface.py @@ -1,10 +1,25 @@ +import jax import jax.numpy as jnp +from dcegm.interfaces.index_functions import ( + get_state_choice_index_per_discrete_states_and_choices, +) from dcegm.interfaces.interface import ( - policy_and_value_for_state_choice_vec, + choice_policies_for_states, + choice_values_for_states, + policy_and_value_for_states_and_choices, policy_for_state_choice_vec, - value_for_state_choice_vec, + value_for_state_and_choice, +) +from dcegm.interfaces.interface_checks import check_states_and_choices +from dcegm.likelihood import ( + calc_choice_probs_for_states, + get_state_choice_index_per_discrete_states, +) +from dcegm.pre_processing.alternative_sim_functions import ( + generate_alternative_sim_functions, ) +from dcegm.pre_processing.shared import try_jax_array from dcegm.simulation.sim_utils import create_simulation_df from dcegm.simulation.simulate import simulate_all_periods @@ -32,8 +47,29 @@ def __init__( self.model_structure = model.model_structure self.model_funcs = model.model_funcs self.model_specs = model.model_specs + self.specs_without_jax = model.specs_without_jax self.alternative_sim_funcs = model.alternative_sim_funcs + def set_alternative_sim_funcs( + self, alternative_sim_specifications, alternative_specs=None + ): + if alternative_specs is None: + self.alternative_sim_specs = self.model_specs + alternative_specs_without_jax = self.specs_without_jax + else: + self.alternative_sim_specs = jax.tree_util.tree_map( + try_jax_array, alternative_specs + ) + alternative_specs_without_jax = alternative_specs + + alternative_sim_funcs = generate_alternative_sim_functions( + model_specs=alternative_specs_without_jax, + model_specs_jax=self.alternative_sim_specs, + **alternative_sim_specifications, + ) + self.model.alternative_sim_funcs = alternative_sim_funcs + self.alternative_sim_funcs = alternative_sim_funcs + def simulate(self, states_initial, seed): sim_dict = simulate_all_periods( @@ -51,20 +87,20 @@ def simulate(self, states_initial, seed): ) return create_simulation_df(sim_dict) - def value_and_policy_for_state_and_choice(self, state, choice): + def policy_and_value_for_states_and_choices(self, states, choices): """Get the value and policy for a given state and choice. Args: - state: The state for which to get the value and policy. - choice: The choice for which to get the value and policy. + states: The state for which to get the value and policy. + choices: The choice for which to get the value and policy. 

        Returns:
            A tuple containing the policy and value for the given states and
            choices.

        """
-        return policy_and_value_for_state_choice_vec(
-            states=state,
-            choice=choice,
+        return policy_and_value_for_states_and_choices(
+            states=states,
+            choices=choices,
            model_config=self.model_config,
            model_structure=self.model_structure,
            model_funcs=self.model_funcs,
@@ -74,21 +110,21 @@ def value_and_policy_for_state_and_choice(self, state, choice):
            policy_solved=self.policy,
        )

-    def value_for_state_and_choice(self, state, choice):
+    def value_for_states_and_choices(self, states, choices):
        """Get the value for given states and choices.

        Args:
-            state: The state for which to get the value.
-            choice: The choice for which to get the value.
+            states: The states for which to get the value.
+            choices: The choices for which to get the value.

        Returns:
            The value for the given states and choices.

        """
-        return value_for_state_choice_vec(
-            states=state,
-            choice=choice,
+        return value_for_state_and_choice(
+            states=states,
+            choices=choices,
            model_config=self.model_config,
            model_structure=self.model_structure,
            model_funcs=self.model_funcs,
@@ -97,12 +133,12 @@ def value_for_state_and_choice(self, state, choice):
            value_solved=self.value,
        )

-    def policy_for_state_and_choice(self, state, choice):
+    def policy_for_states_and_choices(self, states, choices):
        """Get the policy for given states and choices.

        Args:
-            state: The state for which to get the policy.
-            choice: The choice for which to get the policy.
+            states: The states for which to get the policy.
+            choices: The choices for which to get the policy.

        Returns:
            The policy for the given states and choices.

@@ -110,50 +146,125 @@ def policy_for_state_and_choice(self, state, choice):
        """
        return policy_for_state_choice_vec(
-            states=state,
-            choice=choice,
+            states=states,
+            choices=choices,
            model_config=self.model_config,
            model_structure=self.model_structure,
            endog_grid_solved=self.endog_grid,
            policy_solved=self.policy,
        )

-    def get_solution_for_discrete_state_choice(self, states, choice):
+    def get_solution_for_discrete_state_choice(self, states, choices):
        """Get the solution container for a given discrete state and choice
        combination.

        Args:
            states: The states for which to get the solution.
-            choice: The choice for which to get the solution.
+            choices: The choices for which to get the solution.

        Returns:
            A tuple containing the wealth grid, value grid, and policy grid for the
            given state and choice.

        """
-        # Get the value and policy for a given state and choice.
-
-        map_state_choice_to_index = self.model_structure[
-            "map_state_choice_to_index_with_proxy"
-        ]
-        discrete_states_names = self.model_structure["discrete_states_names"]
-
-        if "dummy_stochastic" in discrete_states_names:
-            state_choice_vec = {
-                **states,
-                "choice": choice,
-                "dummy_stochastic": 0,
-            }
-        else:
-            state_choice_vec = {
-                **states,
-                "choice": choice,
-            }
+        # Check if the states and choices are valid according to the model structure.
+        state_choices = check_states_and_choices(
+            states=states,
+            choices=choices,
+            model_structure=self.model_structure,
+        )

-        state_choice_tuple = tuple(
-            state_choice_vec[state] for state in discrete_states_names + ["choice"]
+        # Look up the solution grids for each state-choice combination. We pass
+        # the validated state-choice dict as the states argument; the extra
+        # choice entry is simply ignored by the index lookup.
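+        # The solution containers are indexed by state-choice combination along
+        # axis 0, so a single take along that axis below returns the grids for
+        # every requested combination at once.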
+        state_choice_index = get_state_choice_index_per_discrete_states_and_choices(
+            model_structure=self.model_structure,
+            states=state_choices,
+            choices=state_choices["choice"],
        )

-        state_choice_index = map_state_choice_to_index[state_choice_tuple]
        endog_grid = jnp.take(self.endog_grid, state_choice_index, axis=0)
        value_grid = jnp.take(self.value, state_choice_index, axis=0)
        policy_grid = jnp.take(self.policy, state_choice_index, axis=0)

        return endog_grid, value_grid, policy_grid
+
+    def choice_probabilities_for_states(self, states):
+
+        # To validate the structure of states, temporarily add the period
+        # column as a dummy choice and drop it again afterwards. As a
+        # consequence, error messages from this check can be misleading.
+        state_choices = check_states_and_choices(
+            states=states,
+            choices=states["period"],
+            model_structure=self.model_structure,
+        )
+        state_choices.pop("choice")
+        states = state_choices
+
+        state_choice_idxs = get_state_choice_index_per_discrete_states(
+            states=states,
+            map_state_choice_to_index=self.model_structure[
+                "map_state_choice_to_index_with_proxy"
+            ],
+            discrete_states_names=self.model_structure["discrete_states_names"],
+        )
+
+        return calc_choice_probs_for_states(
+            value_solved=self.value,
+            endog_grid_solved=self.endog_grid,
+            state_choice_indexes=state_choice_idxs,
+            params=self.params,
+            states=states,
+            model_config=self.model_config,
+            model_funcs=self.model_funcs,
+        )
+
+    def choice_values_for_states(self, states):
+        # To validate the structure of states, temporarily add the period
+        # column as a dummy choice and drop it again afterwards. As a
+        # consequence, error messages from this check can be misleading.
+        state_choices = check_states_and_choices(
+            states=states,
+            choices=states["period"],
+            model_structure=self.model_structure,
+        )
+        state_choices.pop("choice")
+        states = state_choices
+
+        state_choice_idxs = get_state_choice_index_per_discrete_states(
+            states=states,
+            map_state_choice_to_index=self.model_structure[
+                "map_state_choice_to_index_with_proxy"
+            ],
+            discrete_states_names=self.model_structure["discrete_states_names"],
+        )
+        return choice_values_for_states(
+            value_solved=self.value,
+            endog_grid_solved=self.endog_grid,
+            state_choice_indexes=state_choice_idxs,
+            params=self.params,
+            states=states,
+            model_config=self.model_config,
+            model_funcs=self.model_funcs,
+        )
+
+    def choice_policies_for_states(self, states):
+        # To validate the structure of states, temporarily add the period
+        # column as a dummy choice and drop it again afterwards. As a
+        # consequence, error messages from this check can be misleading.
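+        # Like the two methods above, but returning the consumption policy of
+        # each choice.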
+ state_choices = check_states_and_choices( + states=states, + choices=states["period"], + model_structure=self.model_structure, + ) + state_choices.pop("choice") + states = state_choices + + state_choice_idxs = get_state_choice_index_per_discrete_states( + states=states, + map_state_choice_to_index=self.model_structure[ + "map_state_choice_to_index_with_proxy" + ], + discrete_states_names=self.model_structure["discrete_states_names"], + ) + return choice_policies_for_states( + policy_solved=self.policy, + endog_grid_solved=self.endog_grid, + state_choice_indexes=state_choice_idxs, + states=states, + model_config=self.model_config, + ) diff --git a/src/dcegm/interpolation/interp1d.py b/src/dcegm/interpolation/interp1d.py index 4eed0272..6f7d73a9 100644 --- a/src/dcegm/interpolation/interp1d.py +++ b/src/dcegm/interpolation/interp1d.py @@ -38,9 +38,9 @@ def get_index_high_and_low(x, x_new): def interp1d_policy_and_value_on_wealth( wealth: float | jnp.ndarray, - endog_grid: jnp.ndarray, - policy: jnp.ndarray, - value: jnp.ndarray, + wealth_grid: jnp.ndarray, + policy_grid: jnp.ndarray, + value_grid: jnp.ndarray, compute_utility: Callable, state_choice_vec: Dict[str, int], params: Dict[str, float], @@ -50,9 +50,9 @@ def interp1d_policy_and_value_on_wealth( Args: wealth (float | jnp.ndarray): New wealth point(s) to interpolate. - endog_grid (jnp.ndarray): Solved endogenous wealth grid. - policy (jnp.ndarray): Solved policy function. - value (jnp.ndarray): Solved value function. + wealth_grid (jnp.ndarray): Solved endogenous wealth grid. + policy_grid (jnp.ndarray): Solved policy function. + value_grid (jnp.ndarray): Solved value function. state_choice_vec (Dict): Dictionary containing a single state and choice. params (Dict): Dictionary containing the model parameters. @@ -65,25 +65,25 @@ def interp1d_policy_and_value_on_wealth( """ # For all choices, the wealth is the same in the solution - ind_high, ind_low = get_index_high_and_low(x=endog_grid, x_new=wealth) + ind_high, ind_low = get_index_high_and_low(x=wealth_grid, x_new=wealth) policy_interp = linear_interpolation_formula( - y_high=policy[ind_high], - y_low=policy[ind_low], - x_high=endog_grid[ind_high], - x_low=endog_grid[ind_low], + y_high=policy_grid[ind_high], + y_low=policy_grid[ind_low], + x_high=wealth_grid[ind_high], + x_low=wealth_grid[ind_low], x_new=wealth, ) value_interp = interp_value_and_check_creditconstraint( - value_high=value[ind_high], - wealth_high=endog_grid[ind_high], - value_low=value[ind_low], - wealth_low=endog_grid[ind_low], + value_high=value_grid[ind_high], + wealth_high=wealth_grid[ind_high], + value_low=value_grid[ind_low], + wealth_low=wealth_grid[ind_low], new_wealth=wealth, compute_utility=compute_utility, - endog_grid_min=endog_grid[1], - value_at_zero_wealth=value[0], + endog_grid_min=wealth_grid[1], + value_at_zero_wealth=value_grid[0], state_choice_vec=state_choice_vec, params=params, discount_factor=discount_factor, @@ -94,7 +94,7 @@ def interp1d_policy_and_value_on_wealth( def interp_value_on_wealth( wealth: float | jnp.ndarray, - endog_grid: jnp.ndarray, + wealth_grid: jnp.ndarray, value: jnp.ndarray, compute_utility: Callable, state_choice_vec: Dict[str, int], @@ -105,7 +105,7 @@ def interp_value_on_wealth( Args: wealth (float): New wealth point to interpolate. - endog_grid (jnp.ndarray): Solved endogenous wealth grid. + wealth_grid (jnp.ndarray): Solved endogenous wealth grid. value (jnp.ndarray): Solved value function. state_choice_vec (Dict): Dictionary containing a single state and choice. 
        params (Dict): Dictionary containing the model parameters.

@@ -115,16 +115,16 @@

    """
-    ind_high, ind_low = get_index_high_and_low(x=endog_grid, x_new=wealth)
+    ind_high, ind_low = get_index_high_and_low(x=wealth_grid, x_new=wealth)

    value_interp = interp_value_and_check_creditconstraint(
        value_high=value[ind_high],
-        wealth_high=endog_grid[ind_high],
+        wealth_high=wealth_grid[ind_high],
        value_low=value[ind_low],
-        wealth_low=endog_grid[ind_low],
+        wealth_low=wealth_grid[ind_low],
        new_wealth=wealth,
        compute_utility=compute_utility,
-        endog_grid_min=endog_grid[1],
+        endog_grid_min=wealth_grid[1],
        value_at_zero_wealth=value[0],
        state_choice_vec=state_choice_vec,
        params=params,
diff --git a/src/dcegm/interpolation/interp_interfaces.py b/src/dcegm/interpolation/interp_interfaces.py
new file mode 100644
index 00000000..f0a66b0b
--- /dev/null
+++ b/src/dcegm/interpolation/interp_interfaces.py
@@ -0,0 +1,133 @@
+from dcegm.interpolation.interp1d import (
+    interp1d_policy_and_value_on_wealth,
+    interp_policy_on_wealth,
+    interp_value_on_wealth,
+)
+from dcegm.interpolation.interp2d import (
+    interp2d_policy_and_value_on_wealth_and_regular_grid,
+    interp2d_policy_on_wealth_and_regular_grid,
+    interp2d_value_on_wealth_and_regular_grid,
+)
+
+
+def interpolate_value_for_state_and_choice(
+    value_grid_state_choice,
+    endog_grid_state_choice,
+    state_choice_vec,
+    params,
+    model_config,
+    model_funcs,
+):
+    """Interpolate the value for a state and choice given the respective grids."""
+    continuous_states_info = model_config["continuous_states_info"]
+    discount_factor = model_funcs["read_funcs"]["discount_factor"](params)
+
+    compute_utility = model_funcs["compute_utility"]
+
+    if continuous_states_info["second_continuous_exists"]:
+        second_continuous = state_choice_vec[
+            continuous_states_info["second_continuous_state_name"]
+        ]
+
+        value = interp2d_value_on_wealth_and_regular_grid(
+            regular_grid=continuous_states_info["second_continuous_grid"],
+            wealth_grid=endog_grid_state_choice,
+            value_grid=value_grid_state_choice,
+            regular_point_to_interp=second_continuous,
+            wealth_point_to_interp=state_choice_vec["assets_begin_of_period"],
+            compute_utility=compute_utility,
+            state_choice_vec=state_choice_vec,
+            params=params,
+            discount_factor=discount_factor,
+        )
+    else:
+
+        value = interp_value_on_wealth(
+            wealth=state_choice_vec["assets_begin_of_period"],
+            wealth_grid=endog_grid_state_choice,
+            value=value_grid_state_choice,
+            compute_utility=compute_utility,
+            state_choice_vec=state_choice_vec,
+            params=params,
+            discount_factor=discount_factor,
+        )
+    return value
+
+
+def interpolate_policy_for_state_and_choice(
+    policy_grid_state_choice,
+    endog_grid_state_choice,
+    state_choice_vec,
+    model_config,
+):
+    """Interpolate the policy for a state and choice given the respective grids."""
+    continuous_states_info = model_config["continuous_states_info"]
+
+    if continuous_states_info["second_continuous_exists"]:
+        second_continuous = state_choice_vec[
+            continuous_states_info["second_continuous_state_name"]
+        ]
+
+        policy = interp2d_policy_on_wealth_and_regular_grid(
+            regular_grid=continuous_states_info["second_continuous_grid"],
+            wealth_grid=endog_grid_state_choice,
+            policy_grid=policy_grid_state_choice,
+            regular_point_to_interp=second_continuous,
+            wealth_point_to_interp=state_choice_vec["assets_begin_of_period"],
+        )
+
+    else:
+        policy = interp_policy_on_wealth(
+            wealth=state_choice_vec["assets_begin_of_period"],
+            endog_grid=endog_grid_state_choice,
+            policy=policy_grid_state_choice,
+        )
+
+    return policy
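+
+# Note: like the interfaces above, the function below dispatches on whether a
+# second continuous state exists, using 2d interpolation over the (second
+# continuous, wealth) grids in that case and 1d interpolation over wealth
+# otherwise.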
+
+
+def interpolate_policy_and_value_for_state_and_choice(
+    value_grid_state_choice,
+    policy_grid_state_choice,
+    endog_grid_state_choice,
+    state_choice_vec,
+    params,
+    model_config,
+    model_funcs,
+):
+    """Interpolate policy and value for a state and choice given the respective
+    grids."""
+    continuous_states_info = model_config["continuous_states_info"]
+
+    compute_utility = model_funcs["compute_utility"]
+    discount_factor = model_funcs["read_funcs"]["discount_factor"](params)
+
+    if continuous_states_info["second_continuous_exists"]:
+
+        second_continuous = state_choice_vec[
+            continuous_states_info["second_continuous_state_name"]
+        ]
+
+        policy, value = interp2d_policy_and_value_on_wealth_and_regular_grid(
+            regular_grid=continuous_states_info["second_continuous_grid"],
+            wealth_grid=endog_grid_state_choice,
+            policy_grid=policy_grid_state_choice,
+            value_grid=value_grid_state_choice,
+            regular_point_to_interp=second_continuous,
+            wealth_point_to_interp=state_choice_vec["assets_begin_of_period"],
+            compute_utility=compute_utility,
+            state_choice_vec=state_choice_vec,
+            params=params,
+            discount_factor=discount_factor,
+        )
+    else:
+        policy, value = interp1d_policy_and_value_on_wealth(
+            wealth=state_choice_vec["assets_begin_of_period"],
+            wealth_grid=endog_grid_state_choice,
+            policy_grid=policy_grid_state_choice,
+            value_grid=value_grid_state_choice,
+            compute_utility=compute_utility,
+            state_choice_vec=state_choice_vec,
+            params=params,
+            discount_factor=discount_factor,
+        )
+
+    return policy, value
diff --git a/src/dcegm/likelihood.py b/src/dcegm/likelihood.py
index e7b284d1..3c9c66bc 100644
--- a/src/dcegm/likelihood.py
+++ b/src/dcegm/likelihood.py
@@ -5,7 +5,7 @@
"""

import copy
-from typing import Any, Dict
+from typing import Dict

import jax
import jax.numpy as jnp
@@ -15,9 +15,8 @@
from dcegm.egm.aggregate_marginal_utility import (
    calculate_choice_probs_and_unsqueezed_logsum,
)
-from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state
-from dcegm.interpolation.interp1d import interp_value_on_wealth
-from dcegm.interpolation.interp2d import interp2d_value_on_wealth_and_regular_grid
+from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states
+from dcegm.interfaces.interface import choice_values_for_states


def create_individual_likelihood_function(
@@ -25,41 +24,33 @@ def create_individual_likelihood_function(
    model_config,
    model_funcs,
    model_specs,
-    backwards_induction,
+    backwards_induction_inner_jit,
    observed_states: Dict[str, int],
-    observed_choices: np.array,
+    observed_choices,
    params_all,
    unobserved_state_specs=None,
    return_model_solution=False,
    use_probability_of_observed_states=True,
+    slow_version=False,
):
-    if unobserved_state_specs is None:
-        choice_prob_func = create_partial_choice_prob_calculation(
-            observed_states=observed_states,
-            observed_choices=observed_choices,
-            model_structure=model_structure,
-            model_config=model_config,
-            model_funcs=model_funcs,
-        )
-    else:
-
-        choice_prob_func = create_choice_prob_func_unobserved_states(
-            model_structure=model_structure,
-            model_config=model_config,
-            model_funcs=model_funcs,
-            model_specs=model_specs,
-            observed_states=observed_states,
-            observed_choices=observed_choices,
-            unobserved_state_specs=unobserved_state_specs,
-            use_probability_of_observed_states=use_probability_of_observed_states,
-        )
+    choice_prob_func = create_choice_prob_function(
+        model_structure=model_structure,
+        model_config=model_config,
+        model_funcs=model_funcs,
+        model_specs=model_specs,
+        observed_states=observed_states,
+        observed_choices=observed_choices,
+        unobserved_state_specs=unobserved_state_specs,
+        use_probability_of_observed_states=use_probability_of_observed_states,
+        return_weight_func=False,
+    )
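+
+    # choice_prob_func maps the solved value and endogenous wealth grids plus
+    # a parameter dict into per-observation choice probabilities.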

    def individual_likelihood(params):

        params_update = params_all.copy()
        params_update.update(params)

-        value, policy, endog_grid = backwards_induction(params_update)
+        value, policy, endog_grid = backwards_induction_inner_jit(params_update)

        choice_probs = choice_prob_func(
            value_in=value,
@@ -80,7 +71,46 @@ def individual_likelihood(params):
        else:
            return neg_likelihood_contributions

-    return jax.jit(individual_likelihood)
+    if slow_version:
+        return individual_likelihood
+    else:
+        return jax.jit(individual_likelihood)
+
+
+def create_choice_prob_function(
+    model_structure,
+    model_config,
+    model_funcs,
+    model_specs,
+    observed_states,
+    observed_choices,
+    unobserved_state_specs,
+    use_probability_of_observed_states,
+    return_weight_func,
+):
+    if unobserved_state_specs is None:
+        choice_prob_func = create_partial_choice_prob_calculation(
+            observed_states=observed_states,
+            observed_choices=observed_choices,
+            model_structure=model_structure,
+            model_config=model_config,
+            model_funcs=model_funcs,
+        )
+    else:
+
+        choice_prob_func = create_choice_prob_func_unobserved_states(
+            model_structure=model_structure,
+            model_config=model_config,
+            model_funcs=model_funcs,
+            model_specs=model_specs,
+            observed_states=observed_states,
+            observed_choices=observed_choices,
+            unobserved_state_specs=unobserved_state_specs,
+            use_probability_of_observed_states=use_probability_of_observed_states,
+            return_weight_func=return_weight_func,
+        )
+
+    return choice_prob_func


def create_choice_prob_func_unobserved_states(
@@ -89,19 +119,17 @@
    model_funcs,
    model_specs,
    observed_states: Dict[str, int],
-    observed_choices: np.array,
+    observed_choices,
    unobserved_state_specs,
    use_probability_of_observed_states=True,
+    return_weight_func=False,
):

    unobserved_state_names = unobserved_state_specs["observed_bools_states"].keys()
    observed_bools = unobserved_state_specs["observed_bools_states"]

    # Read out the weighting variables
-    weighting_vars = unobserved_state_specs["state_choices_weighing"]["states"]
-    weighting_vars["choice"] = unobserved_state_specs["state_choices_weighing"][
-        "choices"
-    ]
+    weighting_vars = unobserved_state_specs["weighting_vars"]

    # Add unobserved states with the suffix "_new" and bools indicating whether
    # the state is observed
    for state_name in unobserved_state_names:
@@ -148,7 +176,7 @@
            for possible_state in possible_states:
                possible_state[state_name][unobserved_state_bool] = state_value
                new_possible_states.append(copy.deepcopy(possible_state))
-            # Same for pre period states
+            # Do the same for the variables passed to the weight function
            for weighting_vars in weighting_vars_for_possible_states:
                weighting_vars[state_name + "_new"][unobserved_state_bool] = state_value
                new_weighting_vars_for_possible_states.append(
@@ -168,6 +196,8 @@

        observed_weights[observed_bools[state_name]] /= n_state_values

+    observed_weights = jnp.asarray(observed_weights)
+
    # Create a list of partial choice probability functions for each unique
    # combination of unobserved states.
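    # Each partial function evaluates the choice probabilities as if one
    # particular combination of unobserved state values were observed; the
    # returned choice_prob_func then averages these with the corresponding
    # weights.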
    partial_choice_probs_unobserved_states = []
@@ -192,6 +222,12 @@

    n_obs = len(observed_choices)

+    # Convert the possible states and weighting vars to jax arrays via tree map
+    possible_states = jax.tree_util.tree_map(lambda x: jnp.asarray(x), possible_states)
+    weighting_vars_for_possible_states = jax.tree_util.tree_map(
+        lambda x: jnp.asarray(x), weighting_vars_for_possible_states
+    )
+
    def choice_prob_func(value_in, endog_grid_in, params_in):
        choice_probs_final = jnp.zeros(n_obs, dtype=jnp.float64)
        integrate_out_weights = jnp.zeros(n_obs, dtype=jnp.float64)
@@ -227,7 +263,35 @@ def choice_prob_func(value_in, endog_grid_in, params_in):

        return choice_probs_final

-    return choice_prob_func
+    def weight_only_func(params_in):
+        weights = np.zeros((n_obs, len(possible_states)), dtype=np.float64)
+        zipped = zip(
+            partial_choice_probs_unobserved_states,
+            possible_states,
+            weighting_vars_for_possible_states,
+        )
+        for count, (partial_choice_prob, unobserved_state, weighting_vars) in enumerate(
+            zipped
+        ):
+            unobserved_weights = jax.vmap(
+                partial_weight_func,
+                in_axes=(None, 0),
+            )(
+                params_in,
+                weighting_vars,
+            )
+
+            weights[:, count] = unobserved_weights
+        return (
+            weights,
+            observed_weights,
+            possible_states,
+            weighting_vars_for_possible_states,
+        )
+
+    if return_weight_func:
+        return choice_prob_func, weight_only_func
+    else:
+        return choice_prob_func


def create_partial_choice_prob_calculation(
@@ -237,7 +301,7 @@
    model_config,
    model_funcs,
):
-    discrete_observed_state_choice_indexes = get_state_choice_index_per_discrete_state(
+    discrete_observed_state_choice_indexes = get_state_choice_index_per_discrete_states(
        states=observed_states,
        map_state_choice_to_index=model_structure[
            "map_state_choice_to_index_with_proxy"
        ],
@@ -276,12 +340,13 @@ calc_choice_prob_for_state_choices(
    and then interpolates them at the observed beginning-of-period wealth.
""" + choice_prob_across_choices = calc_choice_probs_for_states( value_solved=value_solved, endog_grid_solved=endog_grid_solved, - params=params, - observed_states=states, state_choice_indexes=state_choice_indexes, + params=params, + states=states, model_config=model_config, model_funcs=model_funcs, ) @@ -294,63 +359,21 @@ def calc_choice_prob_for_state_choices( def calc_choice_probs_for_states( value_solved, endog_grid_solved, - params, - observed_states, state_choice_indexes, + params, + states, model_config, model_funcs, ): - value_grid_agent = jnp.take( - value_solved, state_choice_indexes, axis=0, mode="fill", fill_value=jnp.nan + choice_values_per_state = choice_values_for_states( + value_solved=value_solved, + endog_grid_solved=endog_grid_solved, + state_choice_indexes=state_choice_indexes, + params=params, + states=states, + model_config=model_config, + model_funcs=model_funcs, ) - endog_grid_agent = jnp.take(endog_grid_solved, state_choice_indexes, axis=0) - - # Read out relevant model objects - continuous_states_info = model_config["continuous_states_info"] - choice_range = model_config["choices"] - - if continuous_states_info["second_continuous_exists"]: - vectorized_interp2d = jax.vmap( - jax.vmap( - interp2d_value_for_state_in_each_choice, - in_axes=(None, None, 0, 0, 0, None, None, None), - ), - in_axes=(0, 0, 0, 0, None, None, None, None), - ) - # Extract second cont state name - second_continuous_state_name = continuous_states_info[ - "second_continuous_state_name" - ] - second_cont_value = observed_states[second_continuous_state_name] - - value_per_agent_interp = vectorized_interp2d( - observed_states, - second_cont_value, - endog_grid_agent, - value_grid_agent, - choice_range, - params, - continuous_states_info["second_continuous_grid"], - model_funcs, - ) - - else: - vectorized_interp1d = jax.vmap( - jax.vmap( - interp1d_value_for_state_in_each_choice, - in_axes=(None, 0, 0, 0, None, None), - ), - in_axes=(0, 0, 0, None, None, None), - ) - - value_per_agent_interp = vectorized_interp1d( - observed_states, - endog_grid_agent, - value_grid_agent, - choice_range, - params, - model_funcs, - ) if model_funcs["taste_shock_function"]["taste_shock_scale_is_scalar"]: taste_shock_scale = model_funcs["taste_shock_function"][ @@ -361,72 +384,17 @@ def calc_choice_probs_for_states( "taste_shock_scale_per_state" ] taste_shock_scale = vmap(taste_shock_scale_per_state_func, in_axes=(0, None))( - observed_states, params + states, params ) taste_shock_scale = taste_shock_scale[:, None] choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum( - choice_values_per_state=value_per_agent_interp, + choice_values_per_state=choice_values_per_state, taste_shock_scale=taste_shock_scale, ) return choice_prob_across_choices -def interp2d_value_for_state_in_each_choice( - state, - second_cont_state, - endog_grid_agent, - value_agent, - choice, - params, - regular_grid, - model_funcs, -): - state_choice_vec = {**state, "choice": choice} - - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - value_interp = interp2d_value_on_wealth_and_regular_grid( - regular_grid=regular_grid, - wealth_grid=endog_grid_agent, - value_grid=value_agent, - regular_point_to_interp=second_cont_state, - wealth_point_to_interp=state["assets_begin_of_period"], - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - - return value_interp - - -def 
interp1d_value_for_state_in_each_choice( - state, - endog_grid_agent, - value_agent, - choice, - params, - model_funcs, -): - state_choice_vec = {**state, "choice": choice} - compute_utility = model_funcs["compute_utility"] - discount_factor = model_funcs["read_funcs"]["discount_factor"](params) - - value_interp = interp_value_on_wealth( - wealth=state["assets_begin_of_period"], - endog_grid=endog_grid_agent, - value=value_agent, - compute_utility=compute_utility, - state_choice_vec=state_choice_vec, - params=params, - discount_factor=discount_factor, - ) - - return value_interp - - def calculate_weights_for_each_state(params, weight_vars, model_specs, weight_func): """Calculate the weights for each state. diff --git a/src/dcegm/numerical_integration.py b/src/dcegm/numerical_integration.py index 8c895449..014d1cbf 100644 --- a/src/dcegm/numerical_integration.py +++ b/src/dcegm/numerical_integration.py @@ -1,5 +1,6 @@ from typing import Tuple +import jax.numpy as jnp import numpy as np from scipy.special import roots_hermite, roots_sh_legendre from scipy.stats import norm @@ -33,10 +34,10 @@ def quadrature_hermite( quad_points_scaled = quad_points * np.sqrt(2) * income_shock_std quad_weights *= 1 / np.sqrt(np.pi) - return quad_points_scaled, quad_weights + return jnp.asarray(quad_points_scaled), jnp.asarray(quad_weights) -def quadrature_legendre(n_quad_points: int) -> Tuple[np.ndarray, np.ndarray]: +def quadrature_legendre(n_quad_points: int) -> Tuple[jnp.ndarray, jnp.ndarray]: """Return the Gauss-Legendre quadrature points and weights. The stochastic Gauss-Legendre quadrature points are shifted points @@ -58,4 +59,4 @@ def quadrature_legendre(n_quad_points: int) -> Tuple[np.ndarray, np.ndarray]: quad_points, quad_weights = roots_sh_legendre(n_quad_points) quad_points_normal = norm.ppf(quad_points) - return quad_points_normal, quad_weights + return jnp.asarray(quad_points_normal), jnp.asarray(quad_weights) diff --git a/src/dcegm/pre_processing/alternative_sim_functions.py b/src/dcegm/pre_processing/alternative_sim_functions.py index bc106881..2e7c60ce 100644 --- a/src/dcegm/pre_processing/alternative_sim_functions.py +++ b/src/dcegm/pre_processing/alternative_sim_functions.py @@ -26,6 +26,7 @@ def generate_alternative_sim_functions( model_config: Dict, model_specs: Dict, + model_specs_jax: Dict, state_space_functions: Dict[str, Callable], budget_constraint: Callable, shock_functions: Dict[str, Callable] = None, @@ -53,6 +54,7 @@ def generate_alternative_sim_functions( model_funcs, _ = process_alternative_sim_functions( model_config=model_config, model_specs=model_specs, + model_specs_jax=model_specs_jax, state_space_functions=state_space_functions, budget_constraint=budget_constraint, shock_functions=shock_functions, @@ -80,6 +82,7 @@ def generate_alternative_sim_functions( def process_alternative_sim_functions( model_config: Dict, model_specs: Dict, + model_specs_jax: Dict, stochastic_states_transition, state_space_functions: Dict[str, Callable], budget_constraint: Callable, @@ -138,7 +141,7 @@ def process_alternative_sim_functions( create_stochastic_transition_function( stochastic_states_transition, model_config=model_config, - model_specs=model_specs, + model_specs=model_specs_jax, continuous_state_name=second_continuous_state_name, ) ) @@ -154,7 +157,7 @@ def process_alternative_sim_functions( ) next_period_continuous_state = process_second_continuous_update_function( - second_continuous_state_name, state_space_functions, model_specs=model_specs + second_continuous_state_name, 
state_space_functions, model_specs=model_specs_jax
    )

    # Budget equation
@@ -162,7 +165,7 @@
        determine_function_arguments_and_partial_model_specs(
            func=budget_constraint,
            continuous_state_name=second_continuous_state_name,
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
        )
    )

@@ -174,8 +177,9 @@

    taste_shock_function_processed, taste_shock_scale_in_params = (
        process_shock_functions(
-            shock_functions,
-            model_specs,
+            shock_functions=shock_functions,
+            model_specs=model_specs,
+            model_specs_jax=model_specs_jax,
            continuous_state_name=second_continuous_state_name,
        )
    )
diff --git a/src/dcegm/pre_processing/check_model_config.py b/src/dcegm/pre_processing/check_model_config.py
index ba41cb67..47083118 100644
--- a/src/dcegm/pre_processing/check_model_config.py
+++ b/src/dcegm/pre_processing/check_model_config.py
@@ -93,9 +93,9 @@ def check_model_config_and_process(model_config):
            second_continuous_state_name
        )

-        second_continuous_state_grid = continuous_states_grids[
-            second_continuous_state_name
-        ]
+        second_continuous_state_grid = jnp.asarray(
+            continuous_states_grids[second_continuous_state_name]
+        )
        continuous_states_info["second_continuous_grid"] = second_continuous_state_grid

        # ToDo: Check that the grid is an array or list and monotonically increasing
diff --git a/src/dcegm/pre_processing/check_model_specs.py b/src/dcegm/pre_processing/check_model_specs.py
index b5281bac..098b49df 100644
--- a/src/dcegm/pre_processing/check_model_specs.py
+++ b/src/dcegm/pre_processing/check_model_specs.py
@@ -11,51 +11,41 @@ def extract_model_specs_info(model_specs):

    # discount_factor processing
    if "discount_factor" in model_specs:
-        read_func_discount_factor = lambda params: jnp.asarray(
-            model_specs["discount_factor"]
-        )
+        discount_factor = jnp.asarray(model_specs["discount_factor"])
+        read_func_discount_factor = lambda params: discount_factor
        discount_factor_in_params = False
    else:
-        read_func_discount_factor = lambda params: jnp.asarray(
-            params["discount_factor"]
-        )
+        read_func_discount_factor = lambda params: params["discount_factor"]
        discount_factor_in_params = True

    # interest_rate processing
    if "interest_rate" in model_specs:
        # Check if interest_rate is a scalar
-        read_func_interest_rate = lambda params: jnp.asarray(
-            model_specs["interest_rate"]
-        )
+        interest_rate = jnp.asarray(model_specs["interest_rate"])
+        read_func_interest_rate = lambda params: interest_rate
        interest_rate_in_params = False
    else:
-        read_func_interest_rate = lambda params: jnp.asarray(params["interest_rate"])
+        read_func_interest_rate = lambda params: params["interest_rate"]
        interest_rate_in_params = True

    # income shock std processing ("income_shock_std")
    if "income_shock_std" in model_specs:
        # Check if income_shock_std is a scalar
-        read_func_income_shock_std = lambda params: jnp.asarray(
-            model_specs["income_shock_std"]
-        )
+        income_shock_std = jnp.asarray(model_specs["income_shock_std"])
+        read_func_income_shock_std = lambda params: income_shock_std
        income_shock_std_in_params = False
    else:
-        read_func_income_shock_std = lambda params: jnp.asarray(
-            params["income_shock_std"]
-        )
+        read_func_income_shock_std = lambda params: params["income_shock_std"]
        income_shock_std_in_params = True

    # income shock mean processing ("income_shock_mean")
    if "income_shock_mean" in model_specs:
        # Check if income_shock_mean is a scalar
-        read_func_income_shock_mean = lambda params: jnp.asarray(
-            model_specs["income_shock_mean"]
-        )
+        income_shock_mean = jnp.asarray(model_specs["income_shock_mean"])
+        read_func_income_shock_mean = lambda params: income_shock_mean
        income_shock_mean_in_params = False
    else:
-        read_func_income_shock_mean = lambda params: jnp.asarray(
-            params["income_shock_mean"]
-        )
+        read_func_income_shock_mean = lambda params: params["income_shock_mean"]
        income_shock_mean_in_params = True

    specs_read_funcs = {
diff --git a/src/dcegm/pre_processing/model_functions/process_model_functions.py b/src/dcegm/pre_processing/model_functions/process_model_functions.py
index eca89668..354eef4c 100644
--- a/src/dcegm/pre_processing/model_functions/process_model_functions.py
+++ b/src/dcegm/pre_processing/model_functions/process_model_functions.py
@@ -1,5 +1,6 @@
from typing import Callable, Dict, Optional

+import jax
import jax.numpy as jnp

from dcegm.pre_processing.model_functions.taste_shock_function import (
@@ -13,6 +14,7 @@
)
from dcegm.pre_processing.shared import (
    determine_function_arguments_and_partial_model_specs,
+    try_jax_array,
)


@@ -70,23 +72,26 @@ def process_model_functions_and_extract_info(
        "second_continuous_state_name"
    ]

+    # Convert the model_specs leaves to jax arrays for functions that are
+    # called later inside jitted code
+    model_specs_jax = jax.tree_util.tree_map(try_jax_array, model_specs)
+
    # Process mandatory functions. Start with utility functions
    compute_utility = determine_function_arguments_and_partial_model_specs(
        func=utility_functions["utility"],
-        model_specs=model_specs,
+        model_specs=model_specs_jax,
        continuous_state_name=second_continuous_state_name,
    )

    compute_marginal_utility = determine_function_arguments_and_partial_model_specs(
        func=utility_functions["marginal_utility"],
-        model_specs=model_specs,
+        model_specs=model_specs_jax,
        continuous_state_name=second_continuous_state_name,
    )

    compute_inverse_marginal_utility = (
        determine_function_arguments_and_partial_model_specs(
            func=utility_functions["inverse_marginal_utility"],
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
            continuous_state_name=second_continuous_state_name,
        )
    )
@@ -99,14 +104,14 @@ def process_model_functions_and_extract_info(
    # Final period utility functions
    compute_utility_final = determine_function_arguments_and_partial_model_specs(
        func=utility_functions_final_period["utility"],
-        model_specs=model_specs,
+        model_specs=model_specs_jax,
        continuous_state_name=second_continuous_state_name,
    )

    compute_marginal_utility_final = (
        determine_function_arguments_and_partial_model_specs(
            func=utility_functions_final_period["marginal_utility"],
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
            continuous_state_name=second_continuous_state_name,
        )
    )
@@ -121,12 +126,12 @@ def process_model_functions_and_extract_info(
        create_stochastic_transition_function(
            stochastic_states_transitions,
            model_config=model_config,
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
            continuous_state_name=second_continuous_state_name,
        )
    )

-    # Now state space functions
+    # Now the state space functions - here we use the original, unconverted
+    # model_specs
    state_specific_choice_set, next_period_deterministic_state, sparsity_condition = (
        process_state_space_functions(
            state_space_functions,
@@ -137,7 +142,7 @@
    )

    next_period_continuous_state = process_second_continuous_update_function(
-        second_continuous_state_name, state_space_functions, model_specs=model_specs
+        second_continuous_state_name, state_space_functions, model_specs=model_specs_jax
    )

    # Budget equation
@@ -145,7 +150,7 @@
        determine_function_arguments_and_partial_model_specs(
            func=budget_constraint,
            continuous_state_name=second_continuous_state_name,
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
        )
    )

@@ -157,8 +162,9 @@

    taste_shock_function_processed, taste_shock_scale_in_params = (
        process_shock_functions(
-            shock_functions,
-            model_specs,
+            shock_functions=shock_functions,
+            model_specs=model_specs,
+            model_specs_jax=model_specs_jax,
            continuous_state_name=second_continuous_state_name,
        )
    )
diff --git a/src/dcegm/pre_processing/model_functions/taste_shock_function.py b/src/dcegm/pre_processing/model_functions/taste_shock_function.py
index 13251fc6..d674f8c9 100644
--- a/src/dcegm/pre_processing/model_functions/taste_shock_function.py
+++ b/src/dcegm/pre_processing/model_functions/taste_shock_function.py
@@ -5,13 +5,15 @@
)


-def process_shock_functions(shock_functions, model_specs, continuous_state_name):
+def process_shock_functions(
+    shock_functions, model_specs, model_specs_jax, continuous_state_name
+):
    taste_shock_function_processed = {}
    shock_functions = {} if shock_functions is None else shock_functions
    if "taste_shock_scale_per_state" in shock_functions.keys():
        taste_shock_scale_per_state = get_taste_shock_function_for_state(
            draw_function_taste_shocks=shock_functions["taste_shock_scale_per_state"],
-            model_specs=model_specs,
+            model_specs=model_specs_jax,
            continuous_state_name=continuous_state_name,
        )
        taste_shock_function_processed["taste_shock_scale_per_state"] = (
@@ -28,10 +30,10 @@ def process_shock_functions(shock_functions, model_specs, continuous_state_name)
            f"Lambda is not a scalar. If there is no draw function provided, "
            f"lambda must be a scalar. Got {lambda_val}."
        )
-        read_func = lambda params: jnp.asarray([model_specs["taste_shock_scale"]])
+        read_func = lambda params: model_specs_jax["taste_shock_scale"]
        taste_shock_scale_in_params = False
    else:
-        read_func = lambda params: jnp.asarray([params["taste_shock_scale"]])
+        read_func = lambda params: params["taste_shock_scale"]
        taste_shock_scale_in_params = True

diff --git a/src/dcegm/pre_processing/setup_model.py b/src/dcegm/pre_processing/setup_model.py
index a2842d35..49e16350 100644
--- a/src/dcegm/pre_processing/setup_model.py
+++ b/src/dcegm/pre_processing/setup_model.py
@@ -2,6 +2,7 @@
from typing import Callable, Dict

import jax
+import jax.numpy as jnp

from dcegm.pre_processing.batches.batch_creation import create_batches_and_information
from dcegm.pre_processing.check_model_config import check_model_config_and_process
@@ -15,7 +16,10 @@
from dcegm.pre_processing.model_structure.stochastic_states import (
    create_stochastic_state_mapping,
)
-from dcegm.pre_processing.shared import create_array_with_smallest_int_dtype
+from dcegm.pre_processing.shared import (
+    create_array_with_smallest_int_dtype,
+    try_jax_array,
+)


def create_model_dict(
@@ -109,12 +113,14 @@ def create_model_dict(
        model_structure.pop("map_state_choice_to_child_states")
        model_structure.pop("map_state_choice_to_index")

+    batch_info = jax.tree.map(create_array_with_smallest_int_dtype, batch_info)
    print("Model setup complete.\n")

    return {
        "model_config": model_config_processed,
        "model_funcs": model_funcs,
-        "model_structure": model_structure,
-        "batch_info": jax.tree.map(create_array_with_smallest_int_dtype, batch_info),
+        # The model structure also contains lists, so we use the try-convert helper
+        "model_structure": jax.tree.map(try_jax_array, model_structure),
+        "batch_info": jax.tree.map(jnp.asarray, batch_info),
    }
diff --git a/src/dcegm/pre_processing/shared.py b/src/dcegm/pre_processing/shared.py
index cffcf7e0..4ff716e9 100644
--- a/src/dcegm/pre_processing/shared.py
+++ b/src/dcegm/pre_processing/shared.py
@@ -65,3 +65,10 @@ def get_smallest_int_type(n_values):
    for dtype in uint_types:
        if np.iinfo(dtype).max >= n_values:
            return dtype
+
+
+def try_jax_array(x):
+    try:
+        return jnp.asarray(x)
+    except Exception:
+        # Leave leaves that cannot be converted (e.g. callables) unchanged.
+        return x
diff --git a/src/dcegm/pre_processing/sol_container.py b/src/dcegm/pre_processing/sol_container.py
new file mode 100644
index 00000000..68e4efe0
--- /dev/null
+++ b/src/dcegm/pre_processing/sol_container.py
@@ -0,0 +1,47 @@
+from typing import Any, Dict
+
+from jax import numpy as jnp
+
+
+def create_solution_container(
+    continuous_states_info: Dict[str, Any],
+    n_total_wealth_grid: int,
+    n_state_choices: int,
+):
+    """Create solution containers for value, policy, and endog_grid."""
+    if continuous_states_info["second_continuous_exists"]:
+        n_second_continuous_grid = continuous_states_info["n_second_continuous_grid"]
+        container_shape = (
+            n_state_choices,
+            n_second_continuous_grid,
+            n_total_wealth_grid,
+        )
+    else:
+        container_shape = (n_state_choices, n_total_wealth_grid)
+
+    # All three containers share the same shape and are initialized with NaNs.
+    value_solved, policy_solved, endog_grid_solved = (
+        jnp.full(container_shape, dtype=jnp.float64, fill_value=jnp.nan)
+        for _ in range(3)
+    )
+
+    return value_solved, policy_solved, endog_grid_solved
diff --git a/src/dcegm/simulation/sim_utils.py b/src/dcegm/simulation/sim_utils.py
index 3f3a802d..fe2e25b7 100644
--- a/src/dcegm/simulation/sim_utils.py
+++ b/src/dcegm/simulation/sim_utils.py
@@ -4,7 +4,7 @@
from jax import numpy as jnp
from jax import vmap

-from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state
+from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states
from dcegm.interpolation.interp1d import interp1d_policy_and_value_on_wealth
from dcegm.interpolation.interp2d import (
    interp2d_policy_and_value_on_wealth_and_regular_grid,
@@ -34,7 +34,7 @@ def interpolate_policy_and_value_for_all_agents(

    if continuous_state_beginning_of_period is not None:

-        discrete_state_choice_indexes = get_state_choice_index_per_discrete_state(
+        discrete_state_choice_indexes = get_state_choice_index_per_discrete_states(
            states=discrete_states_beginning_of_period,
            map_state_choice_to_index=map_state_choice_to_index,
            discrete_states_names=discrete_states_names,
@@ -93,7 +93,7 @@ def interpolate_policy_and_value_for_all_agents(
        return policy_agent, value_agent

    else:
-        discrete_state_choice_indexes = get_state_choice_index_per_discrete_state(
+        discrete_state_choice_indexes = get_state_choice_index_per_discrete_states(
            states=discrete_states_beginning_of_period,
            map_state_choice_to_index=map_state_choice_to_index,
            discrete_states_names=discrete_states_names,
@@ -303,9 +303,9 @@ def interp1d_policy_and_value_function(

    policy_interp, value_interp = interp1d_policy_and_value_on_wealth(
        wealth=wealth_beginning_of_period,
-        endog_grid=endog_grid_agent,
-        policy=policy_agent,
- 
value=value_agent, + wealth_grid=endog_grid_agent, + policy_grid=policy_agent, + value_grid=value_agent, compute_utility=compute_utility, state_choice_vec=state_choice_vec, params=params, diff --git a/src/dcegm/simulation/simulate.py b/src/dcegm/simulation/simulate.py index 5886dc3f..c0ab9a92 100644 --- a/src/dcegm/simulation/simulate.py +++ b/src/dcegm/simulation/simulate.py @@ -7,7 +7,7 @@ import numpy as np from jax import vmap -from dcegm.interfaces.inspect_structure import get_state_choice_index_per_discrete_state +from dcegm.interfaces.index_functions import get_state_choice_index_per_discrete_states from dcegm.simulation.random_keys import draw_random_keys_for_seed from dcegm.simulation.sim_utils import ( compute_final_utility_for_each_choice, @@ -300,7 +300,7 @@ def simulate_final_period( params, compute_utility_final, ) - state_choice_indexes = get_state_choice_index_per_discrete_state( + state_choice_indexes = get_state_choice_index_per_discrete_states( states=states_begin_of_final_period, map_state_choice_to_index=map_state_choice_to_index, discrete_states_names=discrete_states_names, diff --git a/src/dcegm/solve_single_period.py b/src/dcegm/solve_single_period.py index ce7a4cd7..1ee2be5c 100644 --- a/src/dcegm/solve_single_period.py +++ b/src/dcegm/solve_single_period.py @@ -15,6 +15,7 @@ def solve_single_period( cont_grids_next_period, model_funcs, income_shock_weights, + debug_info, ): """Solve a single period of the model using DCEGM.""" (value_solved, policy_solved, endog_grid_solved) = carry @@ -60,31 +61,50 @@ def solve_single_period( state_choice_mat_child, params ) - endog_grid_state_choice, policy_state_choice, value_state_choice = ( - solve_for_interpolated_values( - value_interpolated=value_interpolated, - marginal_utility_interpolated=marginal_utility_interpolated, - state_choice_mat=state_choice_mat, - child_state_idxs=child_states_to_integrate_stochastic, - states_to_choices_child_states=child_state_choices_to_aggr_choice, - params=params, - taste_shock_scale=taste_shock_scale, - taste_shock_scale_is_scalar=taste_shock_scale_is_scalar, - income_shock_weights=income_shock_weights, - continuous_grids_info=continuous_grids_info, - model_funcs=model_funcs, - ) + out_dict_period = solve_for_interpolated_values( + value_interpolated=value_interpolated, + marginal_utility_interpolated=marginal_utility_interpolated, + state_choice_mat=state_choice_mat, + child_state_idxs=child_states_to_integrate_stochastic, + states_to_choices_child_states=child_state_choices_to_aggr_choice, + params=params, + taste_shock_scale=taste_shock_scale, + taste_shock_scale_is_scalar=taste_shock_scale_is_scalar, + income_shock_weights=income_shock_weights, + continuous_grids_info=continuous_grids_info, + model_funcs=model_funcs, + debug_info=debug_info, + ) + value_solved = value_solved.at[state_choices_idxs, :].set(out_dict_period["value"]) + policy_solved = policy_solved.at[state_choices_idxs, :].set( + out_dict_period["policy"] ) - - value_solved = value_solved.at[state_choices_idxs, :].set(value_state_choice) - policy_solved = policy_solved.at[state_choices_idxs, :].set(policy_state_choice) endog_grid_solved = endog_grid_solved.at[state_choices_idxs, :].set( - endog_grid_state_choice + out_dict_period["endog_grid"] ) - carry = (value_solved, policy_solved, endog_grid_solved) + # If we are not in the debug mode, we only return the solution as a tuple and an empty tuple. 
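+    # The (carry, ()) return value presumably matches jax.lax.scan's
+    # convention of returning a carry plus a per-step output from each step.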
+ if debug_info is None: + carry = (value_solved, policy_solved, endog_grid_solved) + return carry, () + + else: + # In debug mode we return a dictionary. + out_dict = { + "value": value_solved, + "policy": policy_solved, + "endog_grid": endog_grid_solved, + } - return carry, () + # If candidates are requested, we add them + if debug_info["return_candidates"]: + out_dict = { + **out_dict, + "value_candidates": out_dict_period["value_candidates"], + "policy_candidates": out_dict_period["policy_candidates"], + "endog_grid_candidates": out_dict_period["endog_grid_candidates"], + } + return out_dict def solve_for_interpolated_values( @@ -99,9 +119,10 @@ def solve_for_interpolated_values( income_shock_weights, continuous_grids_info, model_funcs, + debug_info, ): # EGM step 2) - # Aggregate the marginal utilities and expected values over all state-choice + # Aggregate the marginal utilities and expected values over all child state-choice # combinations and income shock draws marg_util, emax = aggregate_marg_utils_and_exp_values( value_state_choice_specific=value_interpolated, @@ -150,12 +171,20 @@ def solve_for_interpolated_values( has_second_continuous_state=continuous_grids_info["second_continuous_exists"], compute_upper_envelope_for_state_choice=model_funcs["compute_upper_envelope"], ) + out_dict = { + "endog_grid": endog_grid_state_choice, + "policy": policy_state_choice, + "value": value_state_choice, + } - return ( - endog_grid_state_choice, - policy_state_choice, - value_state_choice, - ) + # If candidates are requested, we additionally return them in the output dictionary. + if debug_info is not None: + if debug_info["return_candidates"]: + out_dict["endog_grid_candidates"] = endog_grid_candidate + out_dict["policy_candidates"] = policy_candidate + out_dict["value_candidates"] = value_candidate + + return out_dict def run_upper_envelope( diff --git a/tests/sandbox/discrete_versus_continuous_experience.ipynb b/tests/sandbox/discrete_versus_continuous_experience.ipynb deleted file mode 100644 index 71a1b96e..00000000 --- a/tests/sandbox/discrete_versus_continuous_experience.ipynb +++ /dev/null @@ -1,47 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "0ed8eee6-6946-46ed-9064-f0be2aebb19c", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import jax.numpy as jnp\n", - "import jax" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dd0bbeea-7c35-45ff-ba12-bb4bcfce97bf", - "metadata": {}, - "outputs": [], - "source": [ - "OPTIONS = {}" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/tests/sandbox/time_functions_jax.ipynb b/tests/sandbox/time_functions_jax.ipynb index d888f51c..c642478d 100644 --- a/tests/sandbox/time_functions_jax.ipynb +++ b/tests/sandbox/time_functions_jax.ipynb @@ -1,12 +1,108 @@ { "cells": [ + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:32:08.943565Z", + "start_time": "2025-08-15T09:32:08.305907Z" + } + }, + "cell_type": "code", + "source": [ + "import jax\n", + "import numpy as np" + ], + "id": "3939a0c83c7a102b", + "outputs": [], + "execution_count": 1 + }, + { + 
"metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:33:06.685423Z", + "start_time": "2025-08-15T09:33:06.572482Z" + } + }, + "cell_type": "code", + "source": [ + "def func_a(x, y):\n", + " return x + y\n", + "\n", + "jax.vmap(func_a, in_axes=(0, None))(np.array([2]), 3)" + ], + "id": "83f45f46db8be341", + "outputs": [ + { + "data": { + "text/plain": [ + "Array([5], dtype=int32)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 5 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:52:23.042715Z", + "start_time": "2025-08-15T09:52:23.037510Z" + } + }, + "cell_type": "code", + "source": "isinstance(np.array(2), np.ndarray)", + "id": "d2b3690f1f318672", + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 6 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-15T09:54:14.537765Z", + "start_time": "2025-08-15T09:54:14.530364Z" + } + }, + "cell_type": "code", + "source": [ + "# Check if array is a scalar\n", + "np.array(2).dtype == int" + ], + "id": "d66fc223e808a7e2", + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 10 + }, { "cell_type": "code", "id": "35ab80e3", "metadata": { "ExecuteTime": { - "end_time": "2025-04-10T09:20:40.430908Z", - "start_time": "2025-04-10T09:20:39.276843Z" + "end_time": "2025-08-15T09:52:31.700874Z", + "start_time": "2025-08-15T09:52:30.996171Z" } }, "source": [ @@ -20,8 +116,20 @@ "import numpy as np\n", "from tests.utils.markov_simulator import markov_simulator" ], - "outputs": [], - "execution_count": 2 + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'tests'", + "output_type": "error", + "traceback": [ + "\u001B[31m---------------------------------------------------------------------------\u001B[39m", + "\u001B[31mModuleNotFoundError\u001B[39m Traceback (most recent call last)", + "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[8]\u001B[39m\u001B[32m, line 9\u001B[39m\n\u001B[32m 7\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjax\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mjnp\u001B[39;00m\n\u001B[32m 8\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnumpy\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mas\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mnp\u001B[39;00m\n\u001B[32m----> \u001B[39m\u001B[32m9\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtests\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mutils\u001B[39;00m\u001B[34;01m.\u001B[39;00m\u001B[34;01mmarkov_simulator\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m markov_simulator\n", + "\u001B[31mModuleNotFoundError\u001B[39m: No module named 'tests'" + ] + } + ], + "execution_count": 8 }, { "metadata": { @@ -197,19 +305,19 @@ "evalue": "len() of unsized object", "output_type": "error", "traceback": [ - "\u001b[31m---------------------------------------------------------------------------\u001b[39m", - "\u001b[31mIndexError\u001b[39m Traceback (most recent call last)", - "\u001b[36mFile 
[Notebook output hunk collapsed: the stored Jupyter traceback (IndexError: "tuple index out of range" in ShapedArray._len, surfaced as TypeError: "len() of unsized object" when the jitted helper g calls len() on the traced return value of f) is unchanged apart from re-serializing the ANSI escape prefix from "\u001b" to "\u001B".]
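The collapsed traceback is still informative: under `jax.jit`, `g` receives a traced scalar from `f`, and calling `len()` on a 0-d `ShapedArray` raises the `IndexError` that JAX re-raises as `TypeError: len() of unsized object`. The notebook cell itself is outside this diff, so the sketch below reconstructs it from the traceback frames (`f`, `f_aux`, `g`, `test_a`, `test_b` are inferred names) and shows one jit-safe alternative, branching on the static Python type:

```python
import jax
import jax.numpy as jnp


def f(x, y):
    return x * y


def f_aux(x, y):
    # Returns the main output plus auxiliary outputs.
    return x * y, {"sum": x + y}


def g(func, x, y):
    func_val = func(x, y)
    # `len(func_val)` fails under jit when func_val is a traced 0-d array,
    # which is exactly the TypeError recorded in the traceback above.
    # Branching on the static Python type is traceable:
    if isinstance(func_val, tuple):
        return func_val[0]
    return func_val


jit_g = jax.jit(lambda x, y: g(f, x, y))
jit_g_aux = jax.jit(lambda x, y: g(f_aux, x, y))

test_a = jnp.array(2.0)
test_b = jnp.array(3.0)
print(jit_g(test_a, test_b))  # 6.0
print(jit_g_aux(test_a, test_b))  # 6.0
```

The `isinstance` check works under `jit` because the tuple-vs-array structure is fixed at trace time, whereas the length of a traced scalar is not defined.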
diff --git a/tests/test_discrete_versus_continuous_experience.py b/tests/test_discrete_versus_continuous_experience.py
index 3587afea..2c65545f 100644
--- a/tests/test_discrete_versus_continuous_experience.py
+++ b/tests/test_discrete_versus_continuous_experience.py
@@ -138,24 +138,24 @@ def test_replication_discrete_versus_continuous_experience(
 
     states_cont["assets_begin_of_period"] = wealth_to_test
 
-    value_cont_interp = model_solved_cont.value_for_state_and_choice(
-        state=states_cont,
-        choice=choice,
+    value_cont_interp = model_solved_cont.value_for_states_and_choices(
+        states=states_cont,
+        choices=choice,
     )
-    policy_cont_interp = model_solved_cont.policy_for_state_and_choice(
-        state=states_cont,
-        choice=choice,
+    policy_cont_interp = model_solved_cont.policy_for_states_and_choices(
+        states=states_cont,
+        choices=choice,
     )
 
     states_disc["assets_begin_of_period"] = wealth_to_test
 
-    value_disc_interp = model_solved_disc.value_for_state_and_choice(
-        state=states_disc,
-        choice=choice,
+    value_disc_interp = model_solved_disc.value_for_states_and_choices(
+        states=states_disc,
+        choices=choice,
     )
-    policy_disc_interp = model_solved_disc.policy_for_state_and_choice(
-        state=states_disc,
-        choice=choice,
+    policy_disc_interp = model_solved_disc.policy_for_states_and_choices(
+        states=states_disc,
+        choices=choice,
     )
 
     # policy_cont_interp, = (
diff --git a/tests/test_replication.py b/tests/test_replication.py
index f07259f9..ed97715b 100644
--- a/tests/test_replication.py
+++ b/tests/test_replication.py
@@ -3,13 +3,13 @@
 
 import jax.numpy as jnp
 import pytest
+from interp1d_auxiliary import (
+    linear_interpolation_with_extrapolation,
+)
 from numpy.testing import assert_array_almost_equal as aaae
 
 import dcegm
 import dcegm.toy_models as toy_models
-from tests.utils.interp1d_auxiliary import (
-    linear_interpolation_with_extrapolation,
-)
 
 # Obtain the test directory of the package
 TEST_DIR = Path(__file__).parent
@@ -74,16 +74,17 @@ def test_benchmark_models(model_name):
     policy_expec_interp = linear_interpolation_with_extrapolation(
         x_new=wealth_grid_to_test, x=policy_expec[0], y=policy_expec[1]
     )
-
+    lagged_choice = state_choice_space_to_test[state_choice_idx, 1]
     state = {
-        "period": period,
-        "lagged_choice": state_choice_space_to_test[state_choice_idx, 1],
+        "period": jnp.ones_like(wealth_grid_to_test, dtype=int) * period,
+        "lagged_choice": jnp.ones_like(wealth_grid_to_test, dtype=int)
+        * lagged_choice,
         "assets_begin_of_period": wealth_grid_to_test,
     }
 
     policy_calc_interp, value_calc_interp = (
-        model_solved.value_and_policy_for_state_and_choice(
-            state=state,
-            choice=choice,
+        model_solved.policy_and_value_for_states_and_choices(
+            states=state,
+            choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice,
         )
     )
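The renames above follow one pattern: scalar `*_for_state_and_choice` accessors become batched `*_for_states_and_choices` accessors that expect one array entry per evaluation point. A minimal sketch of the calling convention, with `model_solved` standing in for a solved model object as in `test_replication.py` (its construction is assumed, not shown here):

```python
import jax.numpy as jnp

# Wealth levels at which the solved policy and value are interpolated.
wealth_grid_to_test = jnp.linspace(1.0, 50.0, 100)

period, lagged_choice, choice = 0, 1, 1

# Scalar discrete states are broadcast to arrays matching the wealth grid,
# one entry per evaluation point.
states = {
    "period": jnp.ones_like(wealth_grid_to_test, dtype=int) * period,
    "lagged_choice": jnp.ones_like(wealth_grid_to_test, dtype=int) * lagged_choice,
    "assets_begin_of_period": wealth_grid_to_test,
}
choices = jnp.ones_like(wealth_grid_to_test, dtype=int) * choice

# With a solved model in hand, the batched accessor returns arrays of the
# same length as the inputs:
# policy, value = model_solved.policy_and_value_for_states_and_choices(
#     states=states,
#     choices=choices,
# )
```

Broadcasting `period` and `lagged_choice` with `jnp.ones_like` is what lets a single vectorized call replace a Python loop over wealth points in the tests.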
diff --git a/tests/test_sparse_stochastic_and_batch_sep.py b/tests/test_sparse_stochastic_and_batch_sep.py
index daa47246..ba9707c0 100644
--- a/tests/test_sparse_stochastic_and_batch_sep.py
+++ b/tests/test_sparse_stochastic_and_batch_sep.py
@@ -127,7 +127,7 @@ def test_benchmark_models():
     }
     (endog_grid_full, policy_full, value_full) = (
         model_solved_full.get_solution_for_discrete_state_choice(
-            states=states_dict, choice=state_choices_sparse[:, -1]
+            states=states_dict, choices=state_choices_sparse[:, -1]
         )
     )
diff --git a/tests/test_two_period_continuous_experience.py b/tests/test_two_period_continuous_experience.py
index 98c7fe23..a081093c 100644
--- a/tests/test_two_period_continuous_experience.py
+++ b/tests/test_two_period_continuous_experience.py
@@ -9,10 +9,10 @@
 
 import dcegm
 import dcegm.toy_models as toy_models
-from dcegm.backward_induction import create_solution_container
 from dcegm.final_periods import solve_final_period
 from dcegm.law_of_motion import calc_cont_grids_next_period
 from dcegm.numerical_integration import quadrature_legendre
+from dcegm.pre_processing.sol_container import create_solution_container
 from dcegm.solve_single_period import solve_for_interpolated_values
 
 MAX_WEALTH = 50
@@ -290,7 +290,7 @@ def create_test_inputs():
         endog_grid_solved=endog_grid_solved,
     )
 
-    endog_grid, policy, value_second_last = solve_for_interpolated_values(
+    out_dict_second_last = solve_for_interpolated_values(
         value_interpolated=value_interp_final_period,
         marginal_utility_interpolated=marginal_utility_final_last_period,
         state_choice_mat=last_two_period_batch_info_cont[
@@ -308,15 +308,22 @@ def create_test_inputs():
         income_shock_weights=income_shock_weights,
         continuous_grids_info=model_config["continuous_states_info"],
         model_funcs=model_funcs_cont,
+        debug_info=None,
     )
 
     idx_second_last = last_two_period_batch_info_cont[
         "idx_state_choices_second_last_period"
     ]
 
-    value_solved = value_solved.at[idx_second_last, ...].set(value_second_last)
-    policy_solved = policy_solved.at[idx_second_last, ...].set(policy)
-    endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(endog_grid)
+    value_solved = value_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["value"]
+    )
+    policy_solved = policy_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["policy"]
+    )
+    endog_grid_solved = endog_grid_solved.at[idx_second_last, ...].set(
+        out_dict_second_last["endog_grid"]
+    )
 
     return (
         value_solved,
@@ -442,8 +449,9 @@ def _get_solve_last_two_periods_args(model, params, has_second_continuous_state)
 
     # Create solution containers for value, policy, and endogenous grids
    value_solved, policy_solved, endog_grid_solved = create_solution_container(
-        model_structure=model_structure,
-        model_config=model_config,
+        continuous_states_info=model_config["continuous_states_info"],
+        n_total_wealth_grid=model_config["tuning_params"]["n_total_wealth_grid"],
+        n_state_choices=model_structure["state_choice_space"].shape[0],
     )
 
     return (
diff --git a/tests/test_utility_second_continuous.py b/tests/test_utility_second_continuous.py
index 1e026eb8..705fab75 100644
--- a/tests/test_utility_second_continuous.py
+++ b/tests/test_utility_second_continuous.py
@@ -388,9 +388,9 @@ def test_replication_discrete_versus_continuous_experience(
 
     policy_disc_interp, value_disc_interp = interp1d_policy_and_value_on_wealth(
         wealth=jnp.array(wealth_to_test),
-        endog_grid=endog_grid_disc[idx_state_choice_disc],
-        policy=policy_disc[idx_state_choice_disc],
-        value=value_disc[idx_state_choice_disc],
+        wealth_grid=endog_grid_disc[idx_state_choice_disc],
+        policy_grid=policy_disc[idx_state_choice_disc],
+        value_grid=value_disc[idx_state_choice_disc],
         compute_utility=model_disc.model_funcs["compute_utility"],
         state_choice_vec=state_choice_disc_dict,
         params=PARAMS,
diff --git a/tests/test_varying_shock_scale.py b/tests/test_varying_shock_scale.py
index 156f0bf6..0dc8c9fd 100644
--- a/tests/test_varying_shock_scale.py
+++ b/tests/test_varying_shock_scale.py
@@ -61,18 +61,20 @@ def test_benchmark_models():
     policy_expec_interp = linear_interpolation_with_extrapolation(
         x_new=wealth_grid_to_test, x=policy_expec[0], y=policy_expec[1]
     )
+    lagged_choice = state_choice_space_to_test[state_choice_idx, 1]
     state = {
-        "period": period,
-        "lagged_choice": state_choice_space_to_test[state_choice_idx, 1],
+        "period": jnp.ones_like(wealth_grid_to_test, dtype=int) * period,
+        "lagged_choice": jnp.ones_like(wealth_grid_to_test, dtype=int)
+        * lagged_choice,
         "assets_begin_of_period": wealth_grid_to_test,
     }
     (
         policy_calc_interp,
         value_calc_interp,
-    ) = model_solved.value_and_policy_for_state_and_choice(
-        state=state,
-        choice=choice,
+    ) = model_solved.policy_and_value_for_states_and_choices(
+        states=state,
+        choices=jnp.ones_like(wealth_grid_to_test, dtype=int) * choice,
     )
 
     aaae(policy_expec_interp, policy_calc_interp)
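The `create_solution_container` call in `tests/test_two_period_continuous_experience.py` above now spells out its inputs instead of receiving whole `model_structure` and `model_config` dictionaries. A minimal sketch of what such a factory might do with those three inputs; the `continuous_states_info` keys used here (`second_continuous_exists`, `n_second_continuous_grid`) are assumed names chosen for illustration, not the library's actual fields:

```python
import jax.numpy as jnp


def create_solution_container(
    continuous_states_info, n_total_wealth_grid, n_state_choices
):
    """Allocate NaN-filled value, policy, and endogenous wealth grid arrays.

    Sketch only: the continuous_states_info key names are assumptions.
    """
    if continuous_states_info.get("second_continuous_exists", False):
        # One extra axis for the second continuous state grid.
        shape = (
            n_state_choices,
            continuous_states_info["n_second_continuous_grid"],
            n_total_wealth_grid,
        )
    else:
        shape = (n_state_choices, n_total_wealth_grid)
    value = jnp.full(shape, jnp.nan)
    policy = jnp.full(shape, jnp.nan)
    endog_grid = jnp.full(shape, jnp.nan)
    return value, policy, endog_grid


value, policy, endog_grid = create_solution_container(
    continuous_states_info={"second_continuous_exists": False},
    n_total_wealth_grid=120,
    n_state_choices=36,
)
```

Passing the three inputs explicitly makes the container shapes visible at the call site and lets tests build containers without assembling a full model first.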