Skip to content

Commit

Permalink
Fix validation on injected inference data with unique_sigma_for_each_geo
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731171886
  • Loading branch information
andyl7an authored and The Meridian Authors committed Feb 28, 2025
1 parent 4521db1 commit 2bfcab8
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 3 deletions.
11 changes: 8 additions & 3 deletions meridian/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,14 @@ def _validate_injected_inference_data_group(
self._validate_injected_inference_data_group_coord(
inference_data, group, constants.TIME, self.n_times
)
self._validate_injected_inference_data_group_coord(
inference_data, group, constants.SIGMA_DIM, self._sigma_shape
)
if self.model_spec.unique_sigma_for_each_geo:
self._validate_injected_inference_data_group_coord(
inference_data, group, constants.GEO, self._sigma_shape
)
else:
self._validate_injected_inference_data_group_coord(
inference_data, group, constants.SIGMA_DIM, self._sigma_shape
)
self._validate_injected_inference_data_group_coord(
inference_data,
group,
Expand Down
149 changes: 149 additions & 0 deletions meridian/model/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1897,6 +1897,155 @@ def test_validate_injected_inference_data_prior_incorrect_coordinates(
inference_data=inference_data,
)

@parameterized.named_parameters(
dict(
testcase_name="sigma_dims_unique_sigma",
coord=constants.GEO,
mismatched_priors={
constants.BETA_GOM: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_ORGANIC_MEDIA_CHANNELS,
),
constants.BETA_GORF: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_ORGANIC_RF_CHANNELS,
),
constants.GAMMA_GN: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_NON_MEDIA_CHANNELS,
),
constants.GAMMA_GC: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_CONTROLS,
),
constants.TAU_G: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_CONTROLS,
),
constants.TAU_G_EXCL_BASELINE: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
),
constants.BETA_GM: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_MEDIA_CHANNELS,
),
constants.BETA_GRF: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_RF_CHANNELS,
),
constants.BETA_GOM_DEV: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_ORGANIC_MEDIA_CHANNELS,
),
constants.BETA_GORF_DEV: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_ORGANIC_RF_CHANNELS,
),
constants.GAMMA_GN_DEV: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_NON_MEDIA_CHANNELS,
),
constants.GAMMA_GC_DEV: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
input_data_samples._N_CONTROLS,
),
constants.SIGMA: (
1,
input_data_samples._N_DRAWS,
input_data_samples._N_GEOS + 1,
),
},
mismatched_coord_size=input_data_samples._N_GEOS + 1,
expected_coord_size=input_data_samples._N_GEOS,
unique_sigma=True,
),
dict(
testcase_name="sigma_dims_not_unique_sigma",
coord=constants.SIGMA_DIM,
mismatched_priors={
constants.SIGMA: (
1,
input_data_samples._N_DRAWS,
2,
),
},
mismatched_coord_size=2,
expected_coord_size=1,
unique_sigma=False,
),
)
def test_validate_injected_inference_data_prior_incorrect_sigma_coordinates(
self,
coord,
mismatched_priors,
mismatched_coord_size,
expected_coord_size,
unique_sigma,
):
"""Checks validation fails with incorrect coordinates for sigma."""
model_spec = spec.ModelSpec(unique_sigma_for_each_geo=unique_sigma)
meridian = model.Meridian(
input_data=self.input_data_non_media_and_organic,
model_spec=model_spec,
)
prior_samples = meridian.prior_sampler_callable._sample_prior(self._N_DRAWS)
prior_coords = meridian.create_inference_data_coords(1, self._N_DRAWS)
prior_dims = meridian.create_inference_data_dims()

prior_samples = dict(prior_samples)
for param in mismatched_priors:
prior_samples[param] = tf.zeros(mismatched_priors[param])
prior_coords = dict(prior_coords)
prior_coords[coord] = np.arange(mismatched_coord_size)
if unique_sigma:
prior_coords[constants.GEO] = np.arange(mismatched_coord_size)
else:
prior_coords[constants.SIGMA_DIM] = np.arange(mismatched_coord_size)

inference_data = az.convert_to_inference_data(
prior_samples,
coords=prior_coords,
dims=prior_dims,
group=constants.PRIOR,
)

with self.assertRaisesRegex(
ValueError,
"Injected inference data prior has incorrect coordinate"
f" '{coord}': expected"
f" {expected_coord_size}, got"
f" {mismatched_coord_size}",
):
_ = model.Meridian(
input_data=self.input_data_non_media_and_organic,
model_spec=model_spec,
inference_data=inference_data,
)


if __name__ == "__main__":
absltest.main()

0 comments on commit 2bfcab8

Please sign in to comment.