Skip to content

Commit

Permalink
feat(em): add basic EM
Browse files Browse the repository at this point in the history
still have many missing features:
#9
  • Loading branch information
NickCrews committed Oct 18, 2023
1 parent bdc2dc9 commit e34e149
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 0 deletions.
1 change: 1 addition & 0 deletions mismo/compare/fs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ._plot import plot_weights as plot_weights
from ._train import train_comparison as train_comparison
from ._train import train_comparisons as train_comparisons
from ._train_em import train_using_em as train_using_em
from ._weights import ComparisonWeights as ComparisonWeights
from ._weights import LevelWeights as LevelWeights
from ._weights import Weights as Weights
67 changes: 67 additions & 0 deletions mismo/compare/fs/_train_em.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from __future__ import annotations

from ibis import _
from ibis.expr.types import IntegerColumn, StringColumn, Table

from mismo.compare._comparison import Comparison, Comparisons
from mismo.compare.fs import _train
from mismo.compare.fs._weights import ComparisonWeights, Weights


def train_using_em(
    comparisons: Comparisons,
    left: Table,
    right: Table,
    max_pairs: int | None = None,
    seed: int | None = None,
    *,
    n_iterations: int = 5,
    match_odds_threshold: float = 10,
) -> Weights:
    """Train weights on unlabeled data using an expectation maximization algorithm.

    Candidate pairs are obtained from ``_train.possible_pairs`` and labeled
    once with ``comparisons``. Starting from the weights returned by
    ``_initial_weights``, each iteration:

    1. scores every labeled pair with the current weights,
    2. splits the pairs into matches (``odds >= match_odds_threshold``) and
       nonmatches (everything else), and
    3. re-estimates the weights from those two groups.

    Parameters
    ----------
    comparisons
        The comparisons to train weights for.
    left
        The left table of records.
    right
        The right table of records.
    max_pairs
        Forwarded to ``_train.possible_pairs``.
    seed
        Forwarded to ``_train.possible_pairs``.
    n_iterations
        How many score/split/re-estimate rounds to run. With ``0`` the
        initial weights are returned unchanged.
    match_odds_threshold
        Pairs whose ``odds`` are at least this value are treated as matches
        during an iteration.

    Returns
    -------
    Weights
        The weights produced by the final iteration.

    Raises
    ------
    ValueError
        If ``n_iterations`` is negative.
    """
    if n_iterations < 0:
        raise ValueError(f"n_iterations must be >= 0, got {n_iterations}")

    candidate_pairs = _train.possible_pairs(
        left, right, max_pairs=max_pairs, seed=seed
    )
    compared: Table = comparisons.label_pairs(candidate_pairs, how="name")
    # Keep only the per-comparison label columns. The result is cached because
    # it is re-scored on every iteration.
    # NOTE(review): presumably .cache() materializes the table -- confirm
    # against the ibis docs.
    compared = compared[[c.name for c in comparisons]].cache()

    weights = _initial_weights(comparisons, compared)
    # The predicate is a deferred expression that does not depend on the loop
    # state, so build it once.
    is_match = _.odds >= match_odds_threshold
    for _i in range(n_iterations):
        scored = weights.score(compared)
        matches = scored.filter(is_match)
        nonmatches = scored.filter(~is_match)
        weights = _weights_from_matches_nonmatches(comparisons, matches, nonmatches)
    return weights


def _initial_weights(comparisons: Comparisons, labels: Table) -> Weights:
    """Build the starting Weights, one ComparisonWeights per comparison.

    ``labels`` is indexed by each comparison's name to obtain the label
    column handed to ``_initial_comparison_weights``.
    """
    per_comparison = (
        _initial_comparison_weights(comparison, labels[comparison.name])
        for comparison in comparisons
    )
    return Weights(per_comparison)


def _initial_comparison_weights(
    comparison: Comparison, labels: IntegerColumn | StringColumn
) -> ComparisonWeights:
    """Make the starting weights for a single comparison.

    The m values start out uniform: every level gets ``1 / len(comparison)``.
    The u values are whatever ``_train.level_proportions`` reports for
    ``labels``.
    NOTE(review): presumably this treats the full set of pairs as a stand-in
    for nonmatches -- confirm.

    Parameters
    ----------
    comparison
        The comparison whose levels are being weighted.
    labels
        The label column for this comparison. The annotation admits string
        labels as well as integer ones, matching
        ``_comparison_weights_from_matches_nonmatches``; the visible caller
        labels pairs with ``how="name"``.

    Returns
    -------
    ComparisonWeights
        The result of ``_train.make_weights(comparison, ms, us)``.
    """
    n_levels = len(comparison)
    # Uniform starting guess for m: each level is equally likely.
    ms = [1 / n_levels] * n_levels
    us = _train.level_proportions(comparison, labels)
    return _train.make_weights(comparison, ms, us)


def _weights_from_matches_nonmatches(
    comparisons: Comparisons, matches: Table, nonmatches: Table
) -> Weights:
    """Re-estimate Weights from pairs already split into matches and nonmatches.

    For every comparison, the column named after it is pulled from both
    tables and passed to ``_comparison_weights_from_matches_nonmatches``.
    """
    per_comparison = []
    for comparison in comparisons:
        match_labels = matches[comparison.name]
        nonmatch_labels = nonmatches[comparison.name]
        per_comparison.append(
            _comparison_weights_from_matches_nonmatches(
                comparison, match_labels, nonmatch_labels
            )
        )
    return Weights(per_comparison)


def _comparison_weights_from_matches_nonmatches(
    comparison: Comparison,
    match_labels: IntegerColumn | StringColumn,
    nonmatch_labels: IntegerColumn | StringColumn,
) -> ComparisonWeights:
    """Make weights for one comparison from its match and nonmatch labels.

    The m values are the level proportions among ``match_labels`` and the
    u values are the level proportions among ``nonmatch_labels``, both as
    computed by ``_train.level_proportions``.
    """
    return _train.make_weights(
        comparison,
        _train.level_proportions(comparison, match_labels),  # ms
        _train.level_proportions(comparison, nonmatch_labels),  # us
    )

0 comments on commit e34e149

Please sign in to comment.