Skip to content

Commit

Permalink
Merge pull request #8 from transferwise/linting
Browse files Browse the repository at this point in the history
Linting
  • Loading branch information
julianteichgraber authored Jan 11, 2024
2 parents 6ed1f56 + c4478b2 commit c04c4ed
Show file tree
Hide file tree
Showing 26 changed files with 494 additions and 742 deletions.
25 changes: 24 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "tw-experimentation"
version = "0.1.1.13"
version = "0.1.2"
description = "Wise AB platform"
authors = ["Wise"]
readme = "README.md"
Expand Down Expand Up @@ -56,3 +56,26 @@ optional = true
causaltune = "^0.1.3"


[tool.black]
line-length = 88
target_version = ["py39"]
include = '\.pyi?$'
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.mypy_cache
| \.tox
| \.venv
| _build
| build
| dist
)/
)
'''

[tool.isort]
profile = "black"
line_length = 88
4 changes: 2 additions & 2 deletions tw_experimentation/bayes/bayes_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def set_model(
assert self._model_is_well_defined()

def set_prior_model(self, variable: str, model, model_params: dict):
"""Set prior model for a likelihood model parameter
"""Set prior model for a likelihood model parameter.
Args:
variable (str): name of variab model variable to be fed into likelihood
Expand All @@ -87,7 +87,7 @@ def set_prior_model(self, variable: str, model, model_params: dict):
self.prior_model_params[variable] = model_params

def set_prior_model_param(self, variable: str, model_params: dict):
"""Set parameters for prior
"""Set parameters for prior.
Args:
variable (str): name of variab model variable to be fed into likelihood
Expand Down
74 changes: 31 additions & 43 deletions tw_experimentation/bayes/bayes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@

@dataclass
class BayesResult:
"""
Class to store the results of a Bayesian test
"""
"""Class to store the results of a Bayesian test."""

targets: List[str]
metric_types: List[str]
Expand Down Expand Up @@ -90,8 +88,8 @@ def bayes_factor_decision(
return "reject null"

def prob_greater_than_zero(self, target: str):
"""
Compute the probability that the average treatment effect is greater than zero
"""Compute the probability that the average treatment effect is greater than
zero.
Args:
target (str): target metric
Expand All @@ -103,8 +101,7 @@ def prob_greater_than_zero(self, target: str):
}

def prob_greater_than_z(self, z: float, target: str):
"""
Compute the probability that the average treatment effect is greater than z
"""Compute the probability that the average treatment effect is greater than z.
Args:
z (float): threshold
Expand All @@ -117,8 +114,7 @@ def prob_greater_than_z(self, z: float, target: str):
}

def prob_smaller_than_z(self, z: float, target: str):
"""
Compute the probability that the average treatment effect is smaller than z
"""Compute the probability that the average treatment effect is smaller than z.
Args:
z (float): threshold
Expand All @@ -131,9 +127,8 @@ def prob_smaller_than_z(self, z: float, target: str):
}

def prob_greater_than_z_absolute(self, z: float, target: str):
"""
Compute the probability that the absolute value of
the average treatment effect is greater than z
"""Compute the probability that the absolute value of the average treatment
effect is greater than z.
Args:
z (float): threshold
Expand All @@ -147,8 +142,8 @@ def prob_greater_than_z_absolute(self, z: float, target: str):
}

def prob_within_interval(self, z_lower: float, z_upper: float, target: str):
"""
Compute the probability that the average treatment effect is within the interval [z_lower, z_upper]
"""Compute the probability that the average treatment effect is within the
interval [z_lower, z_upper]
Args:
z_lower (float): lower bound of interval
Expand All @@ -163,8 +158,8 @@ def prob_within_interval(self, z_lower: float, z_upper: float, target: str):
}

def prob_outside_interval(self, z_lower: float, z_upper: float, target: str):
"""
Compute the probability that the average treatment effect is outside the interval [z_lower, z_upper]
"""Compute the probability that the average treatment effect is outside the
interval [z_lower, z_upper]
Args:
z_lower (float): lower bound of interval
Expand All @@ -184,9 +179,8 @@ def rope(
rope_upper: Optional[float] = None,
rope_lower: Optional[float] = None,
):
"""
Compute the probability that the average treatment effect
is in the region of practical equivalence (ROPE)
"""Compute the probability that the average treatment effect is in the region of
practical equivalence (ROPE)
https://easystats.github.io/bayestestR/articles/region_of_practical_equivalence.html
Expand Down Expand Up @@ -215,7 +209,8 @@ def rope(
def _rope_interval_autodetect_intervals(
self, target: str, scale_param: Optional[float] = 0.1
):
"""Compute the ROPE interval based on the standard deviation of the target metric
"""Compute the ROPE interval based on the standard deviation of the target
metric.
Args:
target (str): target metric
Expand All @@ -232,8 +227,7 @@ def _rope_interval_autodetect_intervals(
def _posterior_and_hdi_plot(
self, sample_per_variant, posterior_hdi_per_variant, distribution_opacity=0.3
):
"""
Plot the posterior distribution and the high density interval (HDI)
"""Plot the posterior distribution and the high density interval (HDI)
Args:
sample_per_variant (dict): dictionary of posterior samples
Expand Down Expand Up @@ -306,8 +300,8 @@ def _posterior_and_hdi_plot(
return fig

def fig_posterior_by_target(self, target: str, distribution_opacity: float = 0.3):
"""
Plot the posterior distribution and the high density interval (HDI) of the expected value
"""Plot the posterior distribution and the high density interval (HDI) of the
expected value.
Args:
target (str): target metric
Expand All @@ -331,8 +325,8 @@ def fig_posterior_by_target(self, target: str, distribution_opacity: float = 0.3
def fig_posterior_cdf_by_target(
self, target: str, distribution_opacity: float = 0.3, facet_rows_variant=False
):
"""
Generates a plot of the empirical cumulative distribution (ECDF) function of treatment effect for a given target.
"""Generates a plot of the empirical cumulative distribution (ECDF) function of
treatment effect for a given target.
Args:
target (str): The target for which to generate the plot.
Expand Down Expand Up @@ -360,8 +354,8 @@ def fig_posterior_cdf_by_target(
def fig_posterior_difference_by_target(
self, target: str, distribution_opacity: float = 0.3
):
"""
Plot the posterior distribution and the high density interval (HDI) of the expected treatment effect
"""Plot the posterior distribution and the high density interval (HDI) of the
expected treatment effect.
Args:
target (str): target metric
Expand Down Expand Up @@ -390,10 +384,9 @@ def fig_posterior_difference_cdf(
shade_areas: bool = True,
shade_limits: Tuple[Union[float, None], Union[float, None]] = (None, None),
) -> make_subplots:
"""
Generates a plotly figure showing the cumulative density function of the treatment effect
for each variant, based on the posterior distribution of the difference in means between
the variant and the control group.
"""Generates a plotly figure showing the cumulative density function of the
treatment effect for each variant, based on the posterior distribution of the
difference in means between the variant and the control group.
Args:
sample_per_variant (dict): A dictionary mapping variant names to lists of samples.
Expand Down Expand Up @@ -498,7 +491,7 @@ def set_model(
# self.set_prior_model(*model_and_params)

def set_prior_model(self, target, variable: str, model, model_params: dict):
"""Set prior model for a likelihood model parameter
"""Set prior model for a likelihood model parameter.
Args:
variable (str): name of variab model variable to be fed into likelihood
Expand All @@ -516,8 +509,7 @@ def set_prior_model(self, target, variable: str, model, model_params: dict):
self.update_prior_model_param(target, variable, model_params)

def update_prior_model_param(self, target, variable: str, model_params: dict):
"""Update parameters for prior.
The prior model must have been defined before.
"""Update parameters for prior. The prior model must have been defined before.
Args:
variable (str): name of variab model variable to be fed into likelihood
Expand All @@ -533,7 +525,7 @@ def update_prior_model_param(self, target, variable: str, model_params: dict):
self.params_models_per_target[target][variable] = model_params

def set_model_to_default(self):
"""Reset the bayesian model to default settings"""
"""Reset the bayesian model to default settings."""
self.likelihood_model_per_target = {}
self.variables_per_target = {}
self.prior_models_per_target = {}
Expand Down Expand Up @@ -566,9 +558,7 @@ def _setup_bayesmodel(self, target, fit_model=True):
return bm

def compute_posterior(self, store_prior=True, compute_bayes_factor=True, verbose=0):
"""Run the Bayesian model via numpyro to obtain the
posterior distribution
"""
"""Run the Bayesian model via numpyro to obtain the posterior distribution."""

# TODO: save priors on this level

Expand Down Expand Up @@ -645,10 +635,8 @@ def compute_bayes_factor(self):
self._store_prior_means(target, mcmc)

def _compute_posterior_predictive(self):
"""
compute posterior predictive distribution from posterior samples
only possible after model fit
"""
"""Compute posterior predictive distribution from posterior samples only
possible after model fit."""
N_SAMPLES_POST_PRED = 100000
for target in self.ed.targets:
self.post_pred[target] = {}
Expand Down
8 changes: 4 additions & 4 deletions tw_experimentation/bayes/numpyro_monkeypatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@


class ZeroInflatedProbsPatch(Distribution):
"""
ZeroInflatedProbs distribution from Numpyro
"""ZeroInflatedProbs distribution from Numpyro.
https://num.pyro.ai/en/stable/_modules/numpyro/distributions/discrete.html#ZeroInflatedDistribution
Expand All @@ -42,8 +41,9 @@ def __init__(self, base_dist, gate, *, validate_args=None):
# assert base_dist.support.is_discrete
if base_dist.event_shape:
raise ValueError(
"ZeroInflatedProbs expected empty base_dist.event_shape but got {}"
.format(base_dist.event_shape)
"ZeroInflatedProbs expected empty base_dist.event_shape but got {}".format(
base_dist.event_shape
)
)
# XXX: we might need to promote parameters of base_dist but let's keep
# this simplified for now
Expand Down
Loading

0 comments on commit c04c4ed

Please sign in to comment.