Some more tests #16

Merged · 5 commits · Sep 12, 2024
2 changes: 1 addition & 1 deletion .github/workflows/test-clickhouse.yaml
@@ -1,4 +1,4 @@
name: chDB tests
name: Clickhouse tests
on:
  pull_request:
    branches:
12 changes: 12 additions & 0 deletions README.md
@@ -97,6 +97,18 @@ If you require different behaviour (for instance if you have an unusual date for

There is currently no way in Clickhouse to deal directly with date values before 1900; if you require such values you will have to process them manually into a different type, and construct the relevant SQL logic yourself.
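
One possible sketch of such preprocessing, using pandas with an illustrative `dob` string column (the column name and reference date below are assumptions, not part of this package):

```python
import pandas as pd

# Illustrative only: "dob" is a string column which may contain pre-1900
# dates that the Clickhouse date handling described above does not support.
df = pd.DataFrame({"dob": ["1875-03-02", "1992-11-30"]})

# Convert to integer days since an arbitrary reference date, so the values
# can be stored and compared as plain integers rather than dates.
df["dob_days"] = (
    pd.to_datetime(df["dob"], format="%Y-%m-%d") - pd.Timestamp("1800-01-01")
).dt.days
```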

### `NULL` values in `chDB`

When passing data into `chdb` from pandas dataframes or pyarrow tables, `NULL` values in `String` columns are converted into empty strings instead of remaining `NULL`.

For now this is not handled within the package. You can work around the issue by wrapping column names in `NULLIF`:

```python
import splink.comparison_library as cl

fn_comparison = cl.DamerauLevenshteinAtThresholds("NULLIF(first_name, '')")
```
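
An alternative, which the test suite in this pull request uses, is to wrap the column in a `ColumnExpression` with `regex_extract`, since the SQL it generates also includes a `NULLIF`. A minimal sketch of that approach:

```python
import splink.comparison_library as cl
from splink import ColumnExpression

# regex_extract(".*") keeps the whole string, but the generated SQL also
# wraps the column in NULLIF, so empty strings are treated as NULL again.
fn_comparison = cl.DamerauLevenshteinAtThresholds(
    ColumnExpression("first_name").regex_extract(".*")
)
```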

### Term-frequency adjustments

Currently at most one term frequency adjustment can be used with `ClickhouseAPI`.
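
For example, enabling the adjustment on a single comparison, as the updated test settings in this pull request do, works with `ClickhouseAPI`:

```python
import splink.comparison_library as cl

# With ClickhouseAPI, only one comparison in the settings should enable
# term-frequency adjustments for now.
city_comparison = cl.DamerauLevenshteinAtThresholds("city").configure(
    term_frequency_adjustments=True
)
```
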
70 changes: 49 additions & 21 deletions tests/conftest.py
@@ -2,7 +2,7 @@
import splink.comparison_library as cl
from chdb import dbapi
from pytest import fixture, mark, param
from splink import SettingsCreator, block_on, splink_datasets
from splink import ColumnExpression, SettingsCreator, block_on, splink_datasets

from splinkclickhouse import ChDBAPI, ClickhouseAPI

@@ -84,23 +84,51 @@ def fake_1000(version):


@fixture
def fake_1000_settings():
    return SettingsCreator(
        link_type="dedupe_only",
        comparisons=[
            cl.JaroWinklerAtThresholds("first_name"),
            cl.JaroAtThresholds("surname"),
            cl.DateOfBirthComparison(
                "dob",
                input_is_string=True,
            ),
            cl.DamerauLevenshteinAtThresholds("city").configure(
                term_frequency_adjustments=True
            ),
            cl.JaccardAtThresholds("email"),
        ],
        blocking_rules_to_generate_predictions=[
            block_on("first_name", "dob"),
            block_on("surname"),
        ],
    )
def fake_1000_settings_factory():
    def fake_1000_settings(version):
        if version == "clickhouse":
            return SettingsCreator(
                link_type="dedupe_only",
                comparisons=[
                    cl.JaroWinklerAtThresholds("first_name"),
                    cl.JaroAtThresholds("surname"),
                    cl.DateOfBirthComparison(
                        "dob",
                        input_is_string=True,
                    ),
                    cl.DamerauLevenshteinAtThresholds("city").configure(
                        term_frequency_adjustments=True
                    ),
                    cl.JaccardAtThresholds("email"),
                ],
                blocking_rules_to_generate_predictions=[
                    block_on("first_name", "dob"),
                    block_on("surname"),
                ],
            )
        # for chdb we wrap all columns in regex_extract, which also includes a nullif
        # this circumvents the issue where string column NULL values are parsed as
        # empty string instead of NULL when we import them into chdb
        return SettingsCreator(
            link_type="dedupe_only",
            comparisons=[
                cl.JaroWinklerAtThresholds(
                    ColumnExpression("first_name").regex_extract(".*")
                ),
                cl.JaroAtThresholds(ColumnExpression("surname").regex_extract(".*")),
                cl.DateOfBirthComparison(
                    ColumnExpression("dob").regex_extract(".*"),
                    input_is_string=True,
                ),
                cl.DamerauLevenshteinAtThresholds(
                    ColumnExpression("city").regex_extract(".*")
                ).configure(term_frequency_adjustments=True),
                cl.JaccardAtThresholds(ColumnExpression("email").regex_extract(".*")),
            ],
            blocking_rules_to_generate_predictions=[
                block_on("first_name", "dob"),
                block_on("surname"),
            ],
        )

    return fake_1000_settings
55 changes: 43 additions & 12 deletions tests/test_basic_functionality.py
@@ -5,31 +5,35 @@
from splink.exploratory import completeness_chart, profile_columns


def test_make_linker(api_info, fake_1000_factory, fake_1000_settings):
def test_make_linker(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    Linker(df, fake_1000_settings, db_api)


def test_train_u(api_info, fake_1000_factory, fake_1000_settings):
def test_train_u(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.training.estimate_u_using_random_sampling(max_pairs=3e4)


def test_train_lambda(api_info, fake_1000_factory, fake_1000_settings):
def test_train_lambda(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.training.estimate_probability_two_random_records_match(
        [block_on("dob"), block_on("first_name", "surname")], recall=0.8
    )


def test_em_training(api_info, fake_1000_factory, fake_1000_settings):
def test_em_training(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.training.estimate_parameters_using_expectation_maximisation(
        block_on("dob"),
@@ -39,16 +43,18 @@ def test_em_training(api_info, fake_1000_factory, fake_1000_settings):
    )


def test_predict(api_info, fake_1000_factory, fake_1000_settings):
def test_predict(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.inference.predict()


def test_clustering(api_info, fake_1000_factory, fake_1000_settings):
def test_clustering(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    df_predict = linker.inference.predict()
    linker.clustering.cluster_pairwise_predictions_at_threshold(
@@ -57,9 +63,12 @@ def test_clustering(api_info, fake_1000_factory, fake_1000_settings):
    )


def test_cumulative_comparisons(api_info, fake_1000_factory, fake_1000_settings):
def test_cumulative_comparisons(
    api_info, fake_1000_factory, fake_1000_settings_factory
):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])

    blocking_rules = fake_1000_settings.blocking_rules_to_generate_predictions

@@ -89,16 +98,20 @@ def test_completeness(api_info, fake_1000_factory):
    completeness_chart(df, db_api=db_api)


def test_match_weights_chart(api_info, fake_1000_factory, fake_1000_settings):
def test_match_weights_chart(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.visualisations.match_weights_chart()


def test_parameter_estimates_chart(api_info, fake_1000_factory, fake_1000_settings):
def test_parameter_estimates_chart(
    api_info, fake_1000_factory, fake_1000_settings_factory
):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)
    linker.training.estimate_parameters_using_expectation_maximisation(
        block_on("dob"),
@@ -109,19 +122,36 @@ def test_parameter_estimates_chart(api_info, fake_1000_factory, fake_1000_settin
    linker.visualisations.parameter_estimate_comparisons_chart()


def test_m_u_chart(api_info, fake_1000_factory, fake_1000_settings):
def test_m_u_chart(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)

    linker.visualisations.m_u_parameters_chart()


def test_unlinkables_chart(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])

    linker = Linker(df, fake_1000_settings, db_api)

    # db_api.debug_mode = True
    linker.evaluation.unlinkables_chart()
    # import json
    # with open(f"tmp_{api_info['version']}.json", "w+") as f:
    #     json.dump(ch, f)
    # raise TypeError()


def test_comparison_viewer_dashboard(
    api_info, fake_1000_factory, fake_1000_settings, tmp_path
    api_info, fake_1000_factory, fake_1000_settings_factory, tmp_path
):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    fake_1000_settings.retain_intermediate_calculation_columns = True
    linker = Linker(df, fake_1000_settings, db_api)

@@ -130,10 +160,11 @@ def test_comparison_viewer_dashboard(


def test_cluster_studio_dashboard(
    api_info, fake_1000_factory, fake_1000_settings, tmp_path
    api_info, fake_1000_factory, fake_1000_settings_factory, tmp_path
):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    fake_1000_settings.retain_intermediate_calculation_columns = True
    linker = Linker(df, fake_1000_settings, db_api)

3 changes: 2 additions & 1 deletion tests/test_full_run.py
@@ -3,9 +3,10 @@

# this tests similar steps to test_basic_functionality.py, but all together
# this should catch issues we may have in building up cache/other state
def test_full_basic_run(api_info, fake_1000_factory, fake_1000_settings):
def test_full_basic_run(api_info, fake_1000_factory, fake_1000_settings_factory):
    db_api = api_info["db_api"]
    df = fake_1000_factory(api_info["version"])
    fake_1000_settings = fake_1000_settings_factory(api_info["version"])
    linker = Linker(df, fake_1000_settings, db_api)

    # training