Skip to content

Commit

Permalink
Merge pull request #16 from lorentzenchr/add_case_weights_to_compute_…
Browse files Browse the repository at this point in the history
…bias

Add case weights to compute_bias and plot_bias
  • Loading branch information
lorentzenchr authored Feb 13, 2023
2 parents 3203843 + fedc404 commit c88dc4c
Show file tree
Hide file tree
Showing 5 changed files with 368 additions and 214 deletions.
15 changes: 12 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ classifiers = [
]
keywords = ["machine learning", "model diagnostics", "calibration"]
dependencies = [
"pandas >= 1.5",
"pyarrow >= 9.0.0",
"polars >= 0.16.2",
"matplotlib >= 3.6.1",
"scikit-learn >= 1.0",
]
Expand Down Expand Up @@ -78,6 +77,8 @@ exclude = [

[tool.hatch.envs.default]
dependencies = [
"pandas>=1.5",
"pyarrow>=11.0.0",
"pytest",
"pytest-cov",
"pytest-xdist",
Expand Down Expand Up @@ -131,7 +132,15 @@ dependencies = [
[tool.hatch.envs.lint.scripts]
# typing = "mypy --install-types --non-interactive {args:backend/src/hatchling src/hatch tests}"
typing = "mypy --install-types --non-interactive {args:src/model_diagnostics}"
security = "bandit --quiet --recursive --skip B101,B102,B105,B110,B112,B301,B307,B324,B403,B404,B603,B604,B606,B607 {args:.}"
# Note: bandit does not respect the --exclude flag, see
# https://github.com/PyCQA/bandit/issues/657
# --exclude .svn,CVS,.bzr,.hg,.git,*__pycache__,.tox,.eggs,*.egg,.github,.hatch \
# We instead use the path argument instead of {args:.}.
security = """\
bandit --quiet --recursive \
--skip B101,B102,B105,B110,B112,B301,B307,B324,B403,B404,B603,B604,B606,B607 \
{args:src}
"""
style = [
"flake8 --exclude .hatch {args:.}",
"black --check --diff {args:.}",
Expand Down
2 changes: 1 addition & 1 deletion src/model_diagnostics/_utils/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def validate_2_arrays(
a: npt.ArrayLike, b: npt.ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
"""Validate 2 arrays."""
# Note: If the input is an pyarrow array, np.asarray produces a read-only ndarray.
# Note: If the input is a pyarrow array, np.asarray produces a read-only ndarray.
a = np.asarray(a)
b = np.asarray(b)
if a.ndim != b.ndim:
Expand Down
Loading

0 comments on commit c88dc4c

Please sign in to comment.