Skip to content

Commit

Permalink
Some ideas for simulation refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Dec 14, 2023
1 parent c938d2d commit 7f300e1
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 8 deletions.
13 changes: 9 additions & 4 deletions src/lcm/entry_point_updated.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def get_lcm_function(
model_specification,
target,
targets,
):
# Setup
# ==================================================================================
Expand Down Expand Up @@ -38,9 +38,13 @@ def get_lcm_function(
# Functions that simulate the agent's choices
# ==================================================================================
argsolve_continuous_problem = [
model.get_argsolve_continuous_problem(t, on="state_choice_space")
model.get_argsolve_continuous_problem(t, on="sim_state_choice_space")
for t in model.periods
]

argsolve_discrete_problem = [
model.get_argsolve_discrete_problem(t) for t in model.periods
]

draw_next_states = [
model.get_draw_next_state(t, on="state_choice") for t in model.periods
Expand All @@ -59,9 +63,10 @@ def get_lcm_function(

_simulate_model = partial(
forward_simulation,
argsolve_continuous_problem=argsolve_continuous_problem,
argsolve_discrete_problem=argsolve_discrete_problem,
state_indexers=state_indexers,
continuous_choice_grids=continuous_choice_grids,
compute_ccv_policy_functions=argsolve_continuous_problem,
model=model._specification,
next_state=draw_next_states,
)
Expand All @@ -73,4 +78,4 @@ def get_lcm_function(
"simulate": _simulate_model,
"solve_and_simulate": partial(_simulate_model, solve_model=_solve_model),
}
return targets[target]
return targets[targets]
4 changes: 2 additions & 2 deletions src/lcm/model_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def get_solve_discrete_problem(self, period: int):
"""Reference: get_emax_calculator"""
choice_segment = _get_choice_segments(...)

def get_solve_discrete_problem(self, period: int):
"""Reference: get_discrete_policy_calculator"""
def get_argsolve_discrete_problem(self, period: int):
"""Reference: get_discrete_policy_calculator, but depends on choice_segments."""
choice_segment = _get_choice_segments(...)

def get_continuous_choice_grids(self, period: int):
Expand Down
170 changes: 168 additions & 2 deletions src/lcm/simulate_updated.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,168 @@
def forward_simulation():
pass
def forward_simulation(
params,
vf_arr_list,
argsolve_continuous_problem,
argsolve_discrete_problem,
state_indexers,
continuous_choice_grids,
model,
next_state,
initial_states,
logger,
additional_targets=None,
seed=12345,
):
"""Simulate the model forward in time.
Goal:
for t in periods:
sim_scs, sim_choice_segments = create_sim_state_choice_space(sim_states)
cont_policies, cont_solution = argsolve_cont_problem(scs, vf_arr, params)
discrete_policies = argsolve_discrete_problem(cont_solution, sim_choice_segments)
sim_states, sim_key = next_states(sim_states, cont_policies, discrete_policies, sim_key)
"""
# Update the vf_arr_list
# ----------------------------------------------------------------------------------
# We drop the value function array for the first period, because it is not needed
# for the simulation. This is because in the first period the agents only consider
# the current utility and the value function of next period. Similarly, the last
# value function array is not required, as the agents only consider the current
# utility in the last period.
# ==================================================================================
vf_arr_list = vf_arr_list[1:] + [None]

# Preparations
# ==================================================================================
n_periods = len(vf_arr_list)
n_initial_states = len(next(iter(initial_states.values())))

sparse_choice_variables = model.variable_info.query("is_choice & is_sparse").index

# The following variables are updated during the forward simulation
states = initial_states
key = jax.random.PRNGKey(seed=seed)

# Forward simulation
# ==================================================================================
_simulation_results = []

for period in range(n_periods):
# Create data state choice space
# ------------------------------------------------------------------------------
# Initial states are treated as sparse variables, so that the sparse variables
# in the data-state-choice-space correspond to the feasible product of sparse
# choice variables and initial states. The space has to be created in each
# iteration because the states change over time.
# ==============================================================================
data_scs, data_choice_segments = create_data_scs(
states=states,
model=model,
)

# Compute objects dependent on data-state-choice-space
# ==============================================================================
dense_vars_grid_shape = tuple(
len(grid) for grid in data_scs.dense_vars.values()
)
cont_choice_grid_shape = tuple(
len(grid) for grid in continuous_choice_grids[period].values()
)

# Compute optimal continuous choice conditional on discrete choices
# ==============================================================================
ccv_policy, ccv = solve_continuous_problem(
data_scs=data_scs,
compute_ccv=argsolve_continuous_problem[period],
continuous_choice_grids=continuous_choice_grids[period],
vf_arr=vf_arr_list[period],
state_indexers=state_indexers[period],
params=params,
)

# Get optimal discrete choice given the optimal conditional continuous choices
# ==============================================================================
dense_argmax, sparse_argmax, value = argsolve_discrete_problem(
conditional_continuation_value=ccv,
choice_segments=data_choice_segments,
)

# Select optimal continuous choice corresponding to optimal discrete choice
# ------------------------------------------------------------------------------
# The conditional continuous choice argmax is computed for each discrete choice
# in the data-state-choice-space. Here we select the the optimal continuous
# choice corresponding to the optimal discrete choice (dense and sparse).
# ==============================================================================
cont_choice_argmax = filter_ccv_policy(
ccv_policy,
dense_argmax=dense_argmax,
dense_vars_grid_shape=dense_vars_grid_shape,
)
if sparse_argmax is not None:
cont_choice_argmax = cont_choice_argmax[sparse_argmax]

# Convert optimal choice indices to actual choice values
# ==============================================================================
dense_choices = retrieve_non_sparse_choices(
indices=dense_argmax,
grids=data_scs.dense_vars,
grid_shape=dense_vars_grid_shape,
)

cont_choices = retrieve_non_sparse_choices(
indices=cont_choice_argmax,
grids=continuous_choice_grids[period],
grid_shape=cont_choice_grid_shape,
)

sparse_choices = {
key: data_scs.sparse_vars[key][sparse_argmax]
for key in sparse_choice_variables
}

# Store results
# ==============================================================================
choices = {**dense_choices, **sparse_choices, **cont_choices}

_simulation_results.append(
{
"value": value,
"choices": choices,
"states": states,
},
)

# Update states
# ==============================================================================
key, sim_keys = _generate_simulation_keys(
key=key,
ids=model.function_info.query("is_stochastic_next").index,
)

states = next_state(
**states,
**choices,
_period=jnp.repeat(period, n_initial_states),
params=params,
keys=sim_keys,
)

# 'next_' prefix is added by the next_state function, but needs to be removed
# because in the next period, next states are current states.
states = {k.removeprefix("next_"): v for k, v in states.items()}

logger.info("Period: %s", period)

processed = _process_simulated_data(_simulation_results)

if additional_targets is not None:
calculated_targets = _compute_targets(
processed,
targets=additional_targets,
model_functions=model.functions,
params=params,
)
processed = {**processed, **calculated_targets}

return _as_data_frame(processed, n_periods=n_periods)

0 comments on commit 7f300e1

Please sign in to comment.