Skip to content

Improve subsampling and dataset downloads #2323

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
- bump: patch
changes:
fixed:
- US model always downloads from HuggingFace.
- Subsampling improvements.
59 changes: 50 additions & 9 deletions policyengine_api/jobs/calculate_economy_simulation_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from policyengine_api.endpoints.economy.compare import compare_economic_outputs
from policyengine_api.endpoints.economy.reform_impact import set_comment_on_job
from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS
from policyengine_api.country import COUNTRIES, create_policy_reform
from policyengine_api.country import (
COUNTRIES,
create_policy_reform,
PolicyEngineCountry,
)
from policyengine_core.simulations import Microsimulation
from policyengine_core.tools.hugging_face import download_huggingface_dataset
import h5py
Expand Down Expand Up @@ -259,15 +263,13 @@
options.get("max_households", os.environ.get("MAX_HOUSEHOLDS"))
is not None
):
simulation.subsample(
int(
options.get(
"max_households",
os.environ.get("MAX_HOUSEHOLDS", 1_000_000),
)
),
seed=(region, time_period),
simulation = subsample(

Check warning on line 266 in policyengine_api/jobs/calculate_economy_simulation_job.py

View check run for this annotation

Codecov / codecov/patch

policyengine_api/jobs/calculate_economy_simulation_job.py#L266

Added line #L266 was not covered by tests
options=options,
simulation=simulation,
region=region,
time_period=time_period,
reform=reform,
country=country,
)
simulation.default_calculation_period = time_period

Expand Down Expand Up @@ -463,6 +465,45 @@
}


def subsample(
    options: dict,
    simulation: Microsimulation,
    region: str,
    time_period: str,
    reform: dict,
    country: PolicyEngineCountry,
) -> Microsimulation:
    """
    Subsample a microsimulation and rebuild it from the subsampled data.

    The household limit is read from ``options["max_households"]``; when
    that key is absent, the ``MAX_HOUSEHOLDS`` environment variable is
    used, and failing that 1,000,000. The limit is coerced with ``int``.

    Args:
        options (dict): Request options; may contain "max_households".
        simulation (Microsimulation): The simulation to subsample. Its
            ``subsample`` method is called and the return value ignored,
            so it is presumably modified in place (confirm against
            policyengine_core).
        region (str): Region of the run; forms part of the sampling seed.
        time_period (str): Period of the run; forms part of the sampling
            seed and is also forwarded to ``subsample``.
        reform (dict): Reform passed to the rebuilt simulation.
        country (PolicyEngineCountry): Provides
            ``country_package.Microsimulation`` used to rebuild.

    Returns:
        Microsimulation: A new simulation constructed from the input
        dataframe exported by the subsampled simulation, with ``reform``
        applied.
    """
    # Resolve the limit first: explicit option > environment > 1,000,000.
    fallback_limit = os.environ.get("MAX_HOUSEHOLDS", 1_000_000)
    household_limit = int(options.get("max_households", fallback_limit))

    # Seeding with (region, time_period) ties the sample to the request.
    simulation.subsample(
        household_limit,
        seed=(region, time_period),
        time_period=time_period,
    )

    # Export the reduced data and start a fresh simulation from it.
    subsampled_frame = simulation.to_input_dataframe()
    return country.country_package.Microsimulation(
        dataset=subsampled_frame,
        reform=reform,
    )

Check warning on line 504 in policyengine_api/jobs/calculate_economy_simulation_job.py

View check run for this annotation

Codecov / codecov/patch

policyengine_api/jobs/calculate_economy_simulation_job.py#L504

Added line #L504 was not covered by tests


def is_similar(x, y, parent_name: str = "") -> bool:
if x is None or x == {}:
if y is None or y == {}:
Expand Down
Loading