diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..179b568c 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,5 @@ +- bump: patch + changes: + fixed: + - US model always downlaods from HuggingFace. + - Subsampling improvements. diff --git a/policyengine_api/jobs/calculate_economy_simulation_job.py b/policyengine_api/jobs/calculate_economy_simulation_job.py index c976db39..f2923a2c 100644 --- a/policyengine_api/jobs/calculate_economy_simulation_job.py +++ b/policyengine_api/jobs/calculate_economy_simulation_job.py @@ -20,7 +20,11 @@ from policyengine_api.endpoints.economy.compare import compare_economic_outputs from policyengine_api.endpoints.economy.reform_impact import set_comment_on_job from policyengine_api.constants import COUNTRY_PACKAGE_VERSIONS -from policyengine_api.country import COUNTRIES, create_policy_reform +from policyengine_api.country import ( + COUNTRIES, + create_policy_reform, + PolicyEngineCountry, +) from policyengine_api.utils.v2_v1_comparison import ( V2V1Comparison, compute_difference, @@ -371,15 +375,13 @@ def _compute_economy( options.get("max_households", os.environ.get("MAX_HOUSEHOLDS")) is not None ): - simulation.subsample( - int( - options.get( - "max_households", - os.environ.get("MAX_HOUSEHOLDS", 1_000_000), - ) - ), - seed=(region, time_period), + simulation = subsample( + options=options, + simulation=simulation, + region=region, time_period=time_period, + reform=reform, + country=country, ) simulation.default_calculation_period = time_period @@ -575,6 +577,44 @@ def _compute_cliff_impacts(self, simulation: Microsimulation) -> Dict: } +def subsample( + options: dict, + simulation: Microsimulation, + region: str, + time_period: str, + reform: dict, + country: PolicyEngineCountry, +) -> Microsimulation: + """ + Subsamples a microsimulation dataset and reinitializes the simulation with the subsampled data. + Args: + options (dict): A dictionary of options, which may include "max_households" to specify the maximum number of households to subsample. + simulation (Microsimulation): The original microsimulation object to be subsampled. + region (str): The region for which the simulation is being run. + time_period (str): The time period for which the simulation is being run. + reform (dict): A dictionary representing the policy reform to apply to the simulation. + country (PolicyEngineCountry): The country-specific policy engine object. + Returns: + Microsimulation: A new microsimulation object initialized with the subsampled data and the specified reform. + """ + + simulation.subsample( + int( + options.get( + "max_households", + os.environ.get("MAX_HOUSEHOLDS", 1_000_000), + ) + ), + seed=(region, time_period), + time_period=time_period, + ) + input_data = simulation.to_input_dataframe() + simulation = country.country_package.Microsimulation( + dataset=input_data, + reform=reform, + ) + return simulation + class SimulationAPIv2: project: str location: str