JSON-serialization of Blocking Rules #62

Open
jstammers opened this issue Sep 26, 2024 · 2 comments

@jstammers
Contributor

I have a workflow for detecting duplicates in a dataset that I am looking to deploy.

I'd like to be able to configure the blocking rules used to generate candidate pairs as part of the CLI that runs my workflow.

I can add parameters to do this, e.g. `--key-blockers="foo,bar" --coordinate_distance_km=0.01`, but it feels somewhat fragile to me.

It would be useful if it were possible to reconstruct a blocker from a JSON representation, e.g.

```json
{
    "type": "Key",
    "parameters": {
        "key": ["foo", "bar"],
        "name": "foo and bar"
    }
}
```

which would de-serialize to `KeyBlocker(key=("foo", "bar"), name="foo and bar")`.

This could be handled by implementing `.to_dict()` and `.from_dict()` methods for the existing blocker classes, although I'm not sure how this would deal with `Callable` or `Deferred` objects.
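On the CLI side, a JSON config can replace the individual flags regardless of how the blocker (de)serialization ends up looking. A minimal standard-library sketch, assuming a hypothetical `--blocking-config` flag that takes inline JSON; the parsed dict would still need to go through a `from_dict`-style function to become an actual blocker.

```python
import argparse
import json
from typing import Any


def load_blocking_config(value: str) -> dict[str, Any]:
    """argparse `type=` hook: parse inline JSON and check its basic shape."""
    try:
        config = json.loads(value)
    except json.JSONDecodeError as e:
        raise argparse.ArgumentTypeError(f"invalid JSON: {e}") from e
    if not isinstance(config, dict) or "type" not in config:
        raise argparse.ArgumentTypeError('expected a JSON object with a "type" field')
    return config


parser = argparse.ArgumentParser()
# Hypothetical flag name; argparse stores it as args.blocking_config.
parser.add_argument("--blocking-config", type=load_blocking_config)
args = parser.parse_args(
    ["--blocking-config", '{"type": "Key", "parameters": {"key": ["foo", "bar"]}}']
)
print(args.blocking_config["parameters"]["key"])  # ['foo', 'bar']
```

Malformed input is reported through argparse's normal error path (usage message and exit), so a bad config fails before any data is loaded.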

@NickCrews
Owner

Can you implement this externally, without needing to modify the class? e.g.

```python
def to_dict(blocker: KeyBlocker) -> dict: ...
def from_dict(d: dict) -> KeyBlocker: ...
```

I would like this to be possible. If not, let me know what you get stuck with.
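Keeping this outside the class can be as simple as a registry that maps the JSON `"type"` string to a constructor. The sketch below is hypothetical: `KeyBlocker` is a stand-in dataclass (not mismo's class) so it runs on its own, and `BLOCKER_TYPES`, `from_dict` and `to_dict` are illustrative names. Because JSON has no tuple type, lists are converted back to tuples on the way in.

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any


# Hypothetical stand-in for mismo's KeyBlocker so the sketch runs on its own;
# it only stores the two parameters used in the JSON example.
@dataclass(frozen=True)
class KeyBlocker:
    key: tuple[str, ...]
    name: str | None = None


# Hypothetical registry mapping the JSON "type" field to a class, plus its inverse.
BLOCKER_TYPES: dict[str, type] = {"Key": KeyBlocker}
TYPE_NAMES = {cls: name for name, cls in BLOCKER_TYPES.items()}


def _tuplify(value: Any) -> Any:
    """JSON only has arrays, so turn lists back into tuples (recursively)."""
    if isinstance(value, list):
        return tuple(_tuplify(v) for v in value)
    return value


def from_dict(d: dict[str, Any]) -> KeyBlocker:
    """Rebuild a blocker from {"type": ..., "parameters": {...}}."""
    try:
        cls = BLOCKER_TYPES[d["type"]]
    except KeyError as e:
        raise ValueError(f"unknown or missing blocker type in {d!r}") from e
    params = {k: _tuplify(v) for k, v in d.get("parameters", {}).items()}
    return cls(**params)


def to_dict(blocker: KeyBlocker) -> dict[str, Any]:
    """Inverse of from_dict (the key tuple is written out as a list)."""
    return {
        "type": TYPE_NAMES[type(blocker)],
        "parameters": {"key": list(blocker.key), "name": blocker.name},
    }


config = json.loads(
    '{"type": "Key", "parameters": {"key": ["foo", "bar"], "name": "foo and bar"}}'
)
blocker = from_dict(config)
print(blocker)  # KeyBlocker(key=('foo', 'bar'), name='foo and bar')
```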

[Not looking at code right now, but] since `KeyBlocker` can accept arbitrary expressions, e.g. `_.ngrams.unnest()`, I don't think we can make this ser/deser work in general. Since we can't make it work in general, I'm hesitant to build it into the class. But I could be convinced if you explore it and it turns out to be well-defined where the functionality is supported and where it is not.
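One way to make the supported boundary explicit is to whitelist JSON-native parameter values and reject everything else before serializing. A hypothetical sketch (`check_serializable` is an illustrative name); a lambda stands in for a callable or ibis `Deferred` argument so the block runs without ibis installed.

```python
from typing import Any

# Plain values that json.dumps can encode directly.
_SCALARS = (str, int, float, bool, type(None))


def check_serializable(value: Any, path: str = "parameters") -> None:
    """Raise TypeError if `value` contains anything JSON cannot represent.

    Callables and deferred expressions (e.g. `_.ngrams.unnest()`) fall through
    to the final error, which marks where ser/deser is not supported.
    """
    if isinstance(value, _SCALARS):
        return
    if isinstance(value, (list, tuple)):
        for i, item in enumerate(value):
            check_serializable(item, f"{path}[{i}]")
        return
    if isinstance(value, dict):
        for k, v in value.items():
            if not isinstance(k, str):
                raise TypeError(f"{path}: JSON object keys must be str, got {k!r}")
            check_serializable(v, f"{path}.{k}")
        return
    raise TypeError(f"{path}: {type(value).__name__} is not JSON-serializable")


check_serializable({"key": ("foo", "bar"), "name": "foo and bar"})  # passes
try:
    check_serializable({"key": lambda t: t.foo})
except TypeError as e:
    print(e)  # parameters.key: function is not JSON-serializable
```

Failing loudly at serialization time keeps the contract simple: strings, numbers and nested lists/tuples of them round-trip, anything else is rejected with the offending path.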

Also, in general I'd love to see your script if you can share it; I'm always looking for use cases, and for common patterns or pain points we could pull into the core.

@jstammers
Contributor Author

For my particular use case, I've been able to get it working using

```python
def to_dict(blocker: UnionBlocker) -> dict: ...
def from_dict(d: dict) -> UnionBlocker: ...
```

functions, where the blocking arguments are built-in Python objects (strings and lists/tuples of strings), as I haven't needed to ser/deser more complex arguments yet. I agree this might be tricky, particularly if we want to avoid risky practices (e.g. turning the string "_.ngrams.unnest()" into the expression `_.ngrams.unnest()` with `eval`; `ast.literal_eval` only accepts literals and would reject it).
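For a union of blockers, one possible JSON shape is a list of member specs, which also allows unions to nest. This is a hypothetical sketch: `KeySpec` and `UnionSpec` are stand-in dataclasses rather than mismo's `KeyBlocker`/`UnionBlocker`, and the `{"type": "Union", "blockers": [...]}` layout is an assumed format, not one mismo defines.

```python
from __future__ import annotations

import json
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class KeySpec:
    """Hypothetical stand-in for a key-based blocker."""

    key: tuple[str, ...]
    name: str | None = None


@dataclass(frozen=True)
class UnionSpec:
    """Hypothetical stand-in for a UnionBlocker: an ordered tuple of members."""

    blockers: tuple[KeySpec | UnionSpec, ...]


def to_dict(b: KeySpec | UnionSpec) -> dict[str, Any]:
    """Serialize recursively; unions become a list of member dicts."""
    if isinstance(b, UnionSpec):
        return {"type": "Union", "blockers": [to_dict(m) for m in b.blockers]}
    return {"type": "Key", "parameters": {"key": list(b.key), "name": b.name}}


def from_dict(d: dict[str, Any]) -> KeySpec | UnionSpec:
    """Inverse of to_dict; lists from JSON are turned back into tuples."""
    if d["type"] == "Union":
        return UnionSpec(tuple(from_dict(m) for m in d["blockers"]))
    if d["type"] == "Key":
        params = d["parameters"]
        return KeySpec(key=tuple(params["key"]), name=params.get("name"))
    raise ValueError(f"unknown blocker type: {d['type']!r}")


blocker = UnionSpec(
    (KeySpec(("foo",), name="foo"), KeySpec(("bar", "baz"), name="bar and baz"))
)
text = json.dumps(to_dict(blocker))
assert from_dict(json.loads(text)) == blocker
```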

As for my script, besides some initial data wrangling and model evaluation, most of what I've done is implement a class that encapsulates a deduping/matching problem behind an sklearn-style fit/predict API. The main reason for this is that we use MLflow for model tracking and deployment, so implementing something with fit and predict methods reduces the amount of boilerplate needed elsewhere.

Here's my current implementation. It's by no means production-level code yet, but it should give some indication of the additional functionality I've needed:

from __future__ import annotations
from typing import Literal
from ibis import _
import ibis
import ibis.expr.types as ir
import numpy as np
from mismo._recipe import PRecipe
from mismo.block import UnionBlocker
from mismo.cluster import connected_components
from sklearn.metrics import precision_recall_curve, roc_curve
from mismo.fs import train_using_labels
from .utils import from_dict  # my own helper, not part of mismo


class EntityLinker:
    def __init__(
        self,
        dimensions: list[PRecipe],
        threshold_method: Literal["pr", "roc"] = "pr",
        task: Literal["dedupe", "link"] = "dedupe",
        blocking_config: dict | None = None,
    ):
        self.dimensions = dimensions
        self.weights = None
        if threshold_method not in ["pr", "roc"]:
            raise ValueError("threshold_method must be one of 'pr' or 'roc'")
        self._threshold_method = threshold_method
        self.odds_threshold = None
        self._blocking_config = blocking_config
        self._task = task
        self.set_blocker(blocking_config)

    def set_blocker(self, blocker: UnionBlocker | dict | None):
        """Sets the blocker for the model."""
        if isinstance(blocker, dict):
            self._blocker = from_dict(blocker)
        elif isinstance(blocker, UnionBlocker):
            self._blocker = blocker
        else:
            self._blocker = UnionBlocker(*[dim.block for dim in self.dimensions])

    def prepare(self, table: ir.Table) -> ir.Table:
        """Prepare the data for blocking and comparison."""
        if "record_id" not in table.columns:
            raise ValueError("Table must have a 'record_id' column")
        for dim in self.dimensions:
            table = dim.prepare(table)
        return table

    def block(self, left: ir.Table, right: ir.Table) -> ir.Table:
        """Block tables into record pairs."""
        blocked = self._blocker(left, right, task=self._task)
        return blocked

    def compare(self, table: ir.Table) -> ir.Table:
        """Perform comparisons on the blocked pairs."""
        compared = table
        for dim in self.dimensions:
            compared = dim.compare(compared)
        return compared

    def score(self, left: ir.Table, right: ir.Table | None = None) -> ir.Table:
        """Score pairs of records using the model.

        The odds of a match are converted to a probability and, once an odds
        threshold has been set, to a boolean prediction.

        If the tables have a 'label_true' column, the actual label is also calculated.
        """
        assert self.fitted, "Model must be fitted before scoring"
        self._validate_ids(left, right)
        # 'label_true' is optional here so that unlabelled data can be scored

        if not self._is_prepared(left):
            left = self.prepare(left).cache()
        if right is None:
            # Deduping a single table: compare it against itself
            right = left
        elif not self._is_prepared(right):
            right = self.prepare(right).cache()
        blocked = self.block(left, right)
        compared = self.compare(blocked)
        scored = self.weights.score_compared(compared)
        scored = scored.mutate(prob=_.odds / (1 + _.odds))
        if self.odds_threshold is not None:
            # Not yet set while fit() is still searching for the threshold
            scored = scored.mutate(prediction=_.odds > self.odds_threshold)
        if "label_true_l" in scored.columns:
            scored = scored.mutate(
                actual=ibis.or_(
                    _.record_id_l == _.label_true_r, _.record_id_r == _.label_true_l
                )
            )
        assert "prob" in scored.columns
        return scored.cache()

    def _is_prepared(self, table: ir.Table) -> bool:
        """Check if the table has been prepared by this model."""
        p_test = self.prepare(table.limit(1))
        return p_test.columns == table.columns

    @property
    def comparers(self):
        return [dim.comparer for dim in self.dimensions]

    @property
    def fitted(self):
        return self.weights is not None

    def _validate_true(self, left: ir.Table, right: ir.Table):
        """Check that the tables have a 'label_true' column."""
        if "label_true" not in left.columns:
            raise ValueError("Left Table must have a 'label_true' column")
        if right is not None and "label_true" not in right.columns:
            raise ValueError("Right Table must have a 'label_true' column")

    def _validate_ids(self, left: ir.Table, right: ir.Table):
        """Check that the tables have a 'record_id' column."""
        if "record_id" not in left.columns:
            raise ValueError("Left Table must have a 'record_id' column")
        if right is not None and "record_id" not in right.columns:
            raise ValueError("Right Table must have a 'record_id' column")

    def fit(
        self, left: ir.Table, right: ir.Table | None = None, max_pairs: int = 10_000_000
    ) -> EntityLinker:
        """Fit the model to the data.

        Parameters
        ----------
        left : ir.Table
            The left table to use for training.
        right : ir.Table, optional
            The right table to use for training. If not provided, the left table is used.
        max_pairs : int, optional
            The maximum number of random pairs to use for the Fellegi-Sunter model.
        """
        self._validate_ids(left, right)
        self._validate_true(left, right)
        left_prepared = self.prepare(left).cache()
        if right is not None:
            right_prepared = self.prepare(right).cache()
        else:
            right_prepared = left_prepared
        self.weights = train_using_labels(
            self.comparers,
            left_prepared,
            right_prepared,
            max_pairs=max_pairs,
        )
        self.odds_threshold = self._find_optimal_threshold(
            left_prepared, right_prepared
        )
        return self

    def _find_optimal_threshold(self, left: ir.Table, right: ir.Table):
        """Find the optimal odds threshold for the model on labelled data.

        With threshold_method="pr" the F1 score is maximised; with "roc",
        Youden's J statistic (TPR - FPR) is maximised.
        """
        pairs = self.score(left, right)
        true_labels = pairs.actual.execute()
        predicted_probs = pairs.prob.execute()

        fpr, tpr, thresholds = roc_curve(true_labels, predicted_probs)

        optimal_idx = np.argmax(tpr - fpr)  # Index where TPR - FPR is maximized
        optimal_threshold = thresholds[optimal_idx]
        odds_threshold = optimal_threshold / (1 - optimal_threshold)

        precision, recall, pr_thresholds = precision_recall_curve(
            true_labels, predicted_probs
        )
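        # The last precision/recall point has no threshold, so drop it; where
        # precision + recall == 0, treat F1 as 0 instead of dividing by zero.
        precision, recall = precision[:-1], recall[:-1]
        denom = precision + recall
        f1_scores = np.divide(
            2 * precision * recall, denom, out=np.zeros_like(denom), where=denom > 0
        )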
        optimal_idx = np.argmax(f1_scores)
        pr_optimal_threshold = pr_thresholds[optimal_idx]
        pr_odds_threshold = pr_optimal_threshold / (1 - pr_optimal_threshold)

        if self._threshold_method == "pr":
            return pr_odds_threshold
        else:
            return odds_threshold

    def predict_proba(self, left: ir.Table, right: ir.Table | None = None) -> ir.Table:
        """Predict the probability of a match for each candidate pair of records."""
        scored = self.score(left, right)
        return scored.select("record_id_l", "record_id_r", "prob")

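    def predict(
        self,
        left: ir.Table,
        right: ir.Table | None = None,
        threshold: float | None = None,
    ) -> ir.Table:
        """Cluster records whose pairwise odds exceed `threshold`.

        `threshold` is on the odds scale and defaults to the fitted odds_threshold.
        """
        if threshold is None:
            threshold = self.odds_threshold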
        if threshold is None:
            raise ValueError("Threshold must be provided or model must be fitted")
        scored = self.score(left, right)
        links = scored.filter(_.odds > threshold).select("record_id_l", "record_id_r")
        if links.count().execute() == 0:
            print("No links found")
            # Fall back to self-links so every record ends up in its own cluster
            links = left.select(record_id_l=_.record_id, record_id_r=_.record_id)
        links = links.cache()
        records = left.cache()
        clusters = connected_components(
            links=links,
            records=records,
        )
        clusters = clusters.group_by("component").mutate(prediction=_.count() > 1)
        if "label_true" in left.columns:
            clusters = clusters.group_by("label_true").mutate(actual=_.count() > 1)
        return clusters
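`score` turns odds into a probability with p = odds / (1 + odds), and `_find_optimal_threshold` turns a probability threshold back into odds with odds = p / (1 - p). The hypothetical helpers below (illustrative names, pure Python) make that round trip explicit and guard the p = 1 edge case, which matters because, depending on the scikit-learn version, `roc_curve` reports a first threshold that is either infinite or greater than 1. `best_f1_threshold` mirrors the array shapes returned by `precision_recall_curve`, where precision and recall have one more entry than the thresholds; the inputs in the test are made-up toy values.

```python
import math


def prob_to_odds(p: float) -> float:
    """Convert a match probability to odds, p / (1 - p)."""
    if p >= 1.0:
        # p == 1 (or a sentinel threshold above 1): no finite odds can exceed it
        return math.inf
    if p <= 0.0:
        return 0.0
    return p / (1.0 - p)


def odds_to_prob(odds: float) -> float:
    """Convert odds back to a probability, odds / (1 + odds)."""
    if math.isinf(odds):
        return 1.0
    return odds / (1.0 + odds)


def best_f1_threshold(precision, recall, thresholds) -> float:
    """Return the probability threshold with the highest F1 score.

    precision/recall have one more entry than thresholds (the final point has
    no threshold), so only the first len(thresholds) points are considered.
    F1 is taken as 0 where precision + recall == 0.
    """
    n = len(thresholds)
    f1 = [
        2 * p * r / (p + r) if p + r > 0 else 0.0
        for p, r in zip(precision[:n], recall[:n])
    ]
    # max() returns the first index among ties
    return thresholds[max(range(n), key=f1.__getitem__)]
```

With this split, the fitted model can store the odds threshold while reports and CLIs show the more familiar probability.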
