Skip to content

Commit

Permalink
create a shadow version of comparison (level)
Browse files Browse the repository at this point in the history
we cannot directly use dialect properties to adjust the SQL, so we need to make our own version
  • Loading branch information
ADBond committed Dec 12, 2024
1 parent bf9b80c commit 4c05851
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
44 changes: 44 additions & 0 deletions splinkclickhouse/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from splink.internals.comparison_level_library import (
DateMetricType,
)
from splink.internals.comparison_level_library import (
PairwiseStringDistanceFunctionLevel as SplinkPairwiseStringDistanceFunctionLevel,
)

from .column_expression import ColumnExpression as CHColumnExpression
from .dialect import ClickhouseDialect, SplinkDialect
Expand Down Expand Up @@ -164,3 +167,44 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:
f"<= {self.time_threshold_seconds}"
)
return sql


class PairwiseStringDistanceFunctionLevel(SplinkPairwiseStringDistanceFunctionLevel):
    """Pairwise string-distance comparison level with ClickHouse-shaped SQL.

    Only ``create_sql`` is overridden; everything else is inherited from the
    Splink level of the same name.
    """

    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Return the SQL condition for this comparison level.

        The generated expression pairs every element of ``name_l`` with every
        element of ``name_r``, applies the configured string distance to each
        pair, aggregates the results, and compares the aggregate with
        ``self.distance_threshold`` using ``self._comparator()``.

        Args:
            sql_dialect: Dialect supplying the function names and the array
                index base used in the generated SQL.

        Returns:
            The SQL condition as a string.

        Raises:
            KeyError: If ``self.distance_function_name`` or the value of
                ``self._aggregator()`` is not one of the supported keys.
        """
        column = self.col_expression
        column.sql_dialect = sql_dialect

        # Translate Splink's generic distance name into the dialect's function.
        distance_functions = {
            "levenshtein": sql_dialect.levenshtein_function_name,
            "damerau_levenshtein": sql_dialect.damerau_levenshtein_function_name,
            "jaro_winkler": sql_dialect.jaro_winkler_function_name,
            "jaro": sql_dialect.jaro_function_name,
        }
        distance_fn = distance_functions[self.distance_function_name]

        aggregators = {
            "min": sql_dialect.array_min_function_name,
            "max": sql_dialect.array_max_function_name,
        }
        aggregate_fn = aggregators[self._aggregator()]

        transform_fn = sql_dialect.array_transform_function_name
        first_idx = sql_dialect.array_first_index

        # ClickHouse takes its arguments in a different order from the one
        # Splink expects: the lambda has to be the first argument. A UDF cannot
        # paper over this, because a lambda in second position will, in
        # general, make the ClickHouse parser fail.
        pair_distance_lambda = f"""pair -> {distance_fn}(
pair[{first_idx}],
pair[{first_idx + 1}]
)"""

        # Workaround for the lack of a single-level 'flatten': the nested
        # transform yields an array of arrays of [x, y] pairs, which
        # arrayReduce with 'array_concat_agg' concatenates into one array.
        all_pairs = f"""arrayReduce(
'array_concat_agg',
{transform_fn}(
x -> {transform_fn}(
y -> [x, y],
{column.name_r}
),
{column.name_l}
)
)"""

        return f"""{aggregate_fn}(
{transform_fn}(
{pair_distance_lambda},
{all_pairs}
)
) {self._comparator()} {self.distance_threshold}"""
24 changes: 24 additions & 0 deletions splinkclickhouse/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from splink.internals.comparison_library import (
DateOfBirthComparison as SplinkDateOfBirthComparison,
)
from splink.internals.comparison_library import (
PairwiseStringDistanceFunctionAtThresholds as SplinkPairwiseStringDistanceFunctionAtThresholds, # noqa: E501 (can't keep format and check happy)
)
from splink.internals.misc import ensure_is_iterable

import splinkclickhouse.comparison_level_library as cll_ch
Expand Down Expand Up @@ -305,3 +308,24 @@ def create_comparison_levels(self) -> list[ComparisonLevelCreator]:

levels.append(cll.ElseLevel())
return levels


class PairwiseStringDistanceFunctionAtThresholds(
    SplinkPairwiseStringDistanceFunctionAtThresholds
):
    """Pairwise string-distance comparison built on the ClickHouse level.

    Only the level construction is overridden, so that each threshold level
    is the ``cll_ch`` version of ``PairwiseStringDistanceFunctionLevel``.
    """

    def create_comparison_levels(self) -> list[ComparisonLevelCreator]:
        """Assemble the ordered comparison levels for this comparison.

        Returns:
            A null level, an array-intersection level, one pairwise
            string-distance level per entry in ``self.thresholds`` (in the
            order given), and a final else level.
        """
        levels = [cll.NullLevel(self.col_expression)]

        # Any string distance is assumed to rank identical arrays as the most
        # similar, so an array-intersection level precedes the distance levels.
        levels.append(
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1)
        )

        for threshold in self.thresholds:
            levels.append(
                cll_ch.PairwiseStringDistanceFunctionLevel(
                    self.col_expression,
                    distance_threshold=threshold,
                    distance_function_name=self.distance_function_name,
                )
            )

        levels.append(cll.ElseLevel())
        return levels

0 comments on commit 4c05851

Please sign in to comment.