Skip to content

Commit

Permalink
remove matchop (except notebook, wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Aug 30, 2024
1 parent 6a08a10 commit 3e2c0a1
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 64 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128)
# Retrieval

Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights.
We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems.
We apply this as a query encoding transformer.

```python

splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf')
splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf')

```

Expand Down
4 changes: 2 additions & 2 deletions pyt_splade/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from pyt_splade._model import Splade
from pyt_splade._encoder import SpladeEncoder
from pyt_splade._scorer import SpladeScorer
from pyt_splade._utils import Toks2Doc, MatchOp
from pyt_splade._utils import Toks2Doc

SpladeFactory = Splade # backward compatible name
toks2doc = Toks2Doc # backward compatible name

__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc', 'MatchOp']
__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc']
11 changes: 3 additions & 8 deletions pyt_splade/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,27 +58,22 @@ def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=Fa
out_field = 'toks' if sparse else 'doc_vec'
return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

indexing = doc_encoder # backward compatible name
indexing = doc_encoder # backward compatible name

def query_encoder(self, batch_size=100, sparse=True, verbose=False, matchop=False, scale=100) -> pt.Transformer:
def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
"""Returns a transformer that encodes a query field into a query representation.
Args:
batch_size: the batch size to use when encoding
sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
verbose: if True, show a progress bar
matchop: if True, convert the output to MatchOp syntax
scale: the scale to apply to the term frequencies
"""
out_field = 'query_toks' if sparse else 'query_vec'
res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
if matchop:
res = res >> pyt_splade.MatchOp()
return res

def query(self, batch_size=100, sparse=True, verbose=False, matchop=True, scale=100) -> pt.Transformer:
# backward compatible name w/ default matchop=True
return self.query_encoder(batch_size, sparse, verbose, matchop, scale)
query = query_encoder # backward compatible name

def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
"""Returns a transformer that scores documents against queries.
Expand Down
25 changes: 0 additions & 25 deletions pyt_splade/_utils.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,7 @@
import base64
import string
import pandas as pd
import pyterrier_alpha as pta
import pyterrier as pt


class MatchOp(pt.Transformer):
    """Transformer that replaces a ``query_toks`` column with a MatchOp-syntax ``query`` column."""

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Build a MatchOp ``query`` string from each row's ``query_toks`` mapping.

        Args:
            df: a query frame that must contain a ``query_toks`` column; each
                value is iterated with ``.items()`` as (term, weight) pairs.

        Returns:
            The frame returned by ``pt.model.push_queries(df)`` with ``query``
            set to the space-joined MatchOp expressions and ``query_toks``
            dropped.
        """
        pta.validate.query_frame(df, ['query_toks'])

        def _render(toks):
            # One MatchOp expression per (term, weight) pair, space separated.
            return ' '.join(_matchop(term, weight) for term, weight in toks.items())

        # NOTE(review): push_queries presumably preserves the existing query
        # text before it is overwritten below -- confirm against PyTerrier.
        pushed = pt.model.push_queries(df)
        matchop_queries = df.query_toks.apply(_render)
        return pushed.assign(query=matchop_queries).drop(columns=['query_toks'])


# Characters that may appear verbatim in a MatchOp term; built once at import
# time so the per-character check below is a constant-time set lookup instead
# of re-concatenating the alphabet for every character.
_MATCHOP_PLAIN_CHARS = frozenset(string.ascii_letters + string.digits)


def _matchop(t, w):
    """Convert a term and its weight into MatchOp syntax.

    Args:
        t: the term (token) as a string.
        w: the weight attached to the term.

    Returns:
        The MatchOp expression as a string:

        * a term made only of ASCII letters and digits is emitted verbatim;
        * any other term is UTF-8 encoded, base64 encoded and wrapped as
          ``#base64(...)``;
        * when ``w != 1`` the result is additionally wrapped as
          ``#combine:0=<w>(...)``.
    """
    if not all(ch in _MATCHOP_PLAIN_CHARS for ch in t):
        # Terms with characters outside [A-Za-z0-9] (e.g. WordPiece "##ing")
        # are passed through base64 rather than written literally.
        encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8")
        t = f'#base64({encoded})'
    if w != 1:
        # A weight of exactly 1 needs no wrapper.
        t = f'#combine:0={w}({t})'
    return t


class Toks2Doc(pt.Transformer):
"""Converts a toks field into a text field, by scaling the weights by ``mult`` and repeating them."""
def __init__(self, mult: float = 100.):
Expand Down
12 changes: 1 addition & 11 deletions pyt_splade/pt_docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,12 @@ API Documentation
.. autoclass:: pyt_splade.Splade
:members:

Utils
Utils / Internals
------------------------------------------

These utility transformers allow you to convert between sparse representation formats.

.. autoclass:: pyt_splade.Toks2Doc
:members:

.. autoclass:: pyt_splade.MatchOp
:members:

Internals
------------------------------------------

These transformers are returned by :class:`~pyt_splade.Splade` to perform encoding and scoring.

.. autoclass:: pyt_splade.SpladeEncoder
:members:

Expand Down
4 changes: 2 additions & 2 deletions pyt_splade/pt_docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@ Retrieval
---------------------------------------------

Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights.
We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems.
We apply this as a query encoding transformer.

.. code-block:: python
splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf')
splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf')
Scoring
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
license="Creative Commons Attribution-NonCommercial-ShareAlike",
long_description=readme,
install_requires=[
'splade', 'python-terrier', 'pyterrier_alpha',
'splade', 'python-terrier>=0.11.0', 'pyterrier_alpha',
],
)
13 changes: 0 additions & 13 deletions tests/test_matchop.py

This file was deleted.

0 comments on commit 3e2c0a1

Please sign in to comment.