Skip to content

Commit

Permalink
refactoring and integrated documentation (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney authored Sep 2, 2024
1 parent 53489b2 commit 773dbb6
Show file tree
Hide file tree
Showing 15 changed files with 801 additions and 29,037 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include pyt_splade *.rst
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128)
# Retrieval

Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights.
We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems.
We apply this as a query encoding transformer.

```python

splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf')
splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf')

```

Expand Down
29,190 changes: 382 additions & 28,808 deletions msmarco-psg-v1.ipynb

Large diffs are not rendered by default.

Binary file added pyt_splade/.DS_Store
Binary file not shown.
190 changes: 8 additions & 182 deletions pyt_splade/__init__.py
Original file line number Diff line number Diff line change
@@ -1,185 +1,11 @@
import base64
import string
import more_itertools
import pyterrier as pt
__version__ = '0.0.2'

assert pt.started()
from typing import Union
import torch
import numpy as np
import pandas as pd
from pyt_splade._model import Splade
from pyt_splade._encoder import SpladeEncoder
from pyt_splade._scorer import SpladeScorer
from pyt_splade._utils import Toks2Doc

SpladeFactory = Splade # backward compatible name
toks2doc = Toks2Doc # backward compatible name

class Splade():

def __init__(
self,
model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
tokenizer=None,
agg='max',
max_length=256,
device=None):
self.max_length = max_length
self.model = model
self.tokenizer = tokenizer
if device is None:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
else:
self.device = torch.device(device)
if isinstance(model, str):
from splade.models.transformer_rep import Splade
if self.tokenizer is None:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = Splade(model, agg=agg)
self.model.eval()
self.model = self.model.to(self.device)
else:
if self.tokenizer is None:
raise ValueError("you must specify tokenizer if passing a model")

self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
out_field = 'toks' if sparse else 'doc_vec'
return SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

indexing = doc_encoder # backward compatible name

def query_encoder(self, batch_size=100, sparse=True, verbose=False, matchop=False, scale=100) -> pt.Transformer:
out_field = 'query_toks' if sparse else 'query_vec'
res = SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
if matchop:
res = res >> MatchOp()
return res

def query(self, batch_size=100, sparse=True, verbose=False, matchop=True, scale=100) -> pt.Transformer:
# backward compatible name w/ default matchop=True
return self.query_encoder(batch_size, sparse, verbose, matchop, scale)

def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
return SpladeScorer(self, text_field, batch_size, verbose)

def encode(self, texts, rep='d', format='dict', scale=1.):
rtr = []
with torch.no_grad():
reps = self.model(**{rep + '_kwargs': self.tokenizer(
texts,
add_special_tokens=True,
padding="longest", # pad to max sequence length in batch
truncation="longest_first", # truncates to max model length,
max_length=self.max_length,
return_attention_mask=True,
return_tensors="pt",
).to(self.device)})[rep + '_rep']
reps = reps * scale
if format == 'dict':
reps = reps.cpu()
for i in range(reps.shape[0]):
# get the number of non-zero dimensions in the rep:
col = torch.nonzero(reps[i]).squeeze(1).tolist()
# now let's create the bow representation as a dictionary
weights = reps[i, col].cpu().tolist()
# if document cast to int to make the weights ready for terrier indexing
if rep == "d":
weights = list(map(int, weights))
sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
# create the dict removing the weights less than 1, i.e. 0, that are not helpful
d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0}
rtr.append(d)
elif format == 'np':
reps = reps.cpu().numpy()
for i in range(reps.shape[0]):
rtr.append(reps[i])
elif format == 'torch':
rtr = reps
return rtr


SpladeFactory = Splade # backward compatible name


class SpladeEncoder(pt.Transformer):
def __init__(self, splade, text_field, out_field, rep, sparse=True, batch_size=100, verbose=False, scale=1.):
self.splade = splade
self.text_field = text_field
self.out_field = out_field
self.rep = rep
self.sparse = sparse
self.batch_size = batch_size
self.verbose = verbose
self.scale = scale

def transform(self, df):
assert self.text_field in df.columns
it = iter(df[self.text_field])
if self.verbose:
it = pt.tqdm(it, total=len(df), unit=self.text_field)
res = []
for batch in more_itertools.chunked(it, self.batch_size):
res.extend(self.splade.encode(batch, self.rep, format='dict' if self.sparse else 'np', scale=self.scale))
return df.assign(**{self.out_field: res})


class SpladeScorer(pt.Transformer):
def __init__(self, splade, text_field, batch_size=100, verbose=False):
self.splade = splade
self.text_field = text_field
self.batch_size = batch_size
self.verbose = verbose

def transform(self, df):
assert all(f in df.columns for f in ['query', self.text_field])
it = df.groupby('query')
if self.verbose:
it = pt.tqdm(it, unit='query')
res = []
for query, df in it:
query_enc = self.splade.encode([query], 'q', 'torch')
scores = []
for batch in more_itertools.chunked(df[self.text_field], self.batch_size):
doc_enc = self.splade.encode(batch, 'd', 'torch')
scores.append((query_enc @ doc_enc.T).flatten().cpu().numpy())
res.append(df.assign(score=np.concatenate(scores)))
res = pd.concat(res)
from pyterrier.model import add_ranks
res = add_ranks(res)
return res


class MatchOp(pt.Transformer):

def transform(self, df):
assert 'query_toks' in df.columns
from pyterrier.model import push_queries
rtr = push_queries(df)
rtr = rtr.assign(
query=df.query_toks.apply(lambda toks: ' '.join(_matchop(k, v) for k, v in toks.items())))
rtr = rtr.drop(columns=['query_toks'])
return rtr


def _matchop(t, w):
if not all(a in string.ascii_letters + string.digits for a in t):
encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8")
t = f'#base64({encoded})'
if w != 1:
t = f'#combine:0={w}({t})'
return t


def toks2doc(mult=100):
def _dict_tf2text(tfdict):
rtr = ""
for t in tfdict:
for i in range(int(mult * tfdict[t])):
rtr += t + " "
return rtr

def _rowtransform(df):
df = df.copy()
df["text"] = df['toks'].apply(_dict_tf2text)
df.drop(columns=['toks'], inplace=True)
return df

return pt.apply.generic(_rowtransform)
__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc']
53 changes: 53 additions & 0 deletions pyt_splade/_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Literal
import more_itertools
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
import pyt_splade


class SpladeEncoder(pt.Transformer):
"""Encodes a text field using a SPLADE model. The output is a dense or sparse representation of the text field."""

def __init__(
self,
splade: pyt_splade.Splade,
text_field: str,
out_field: str,
rep: Literal['q', 'd'],
sparse: bool = True,
batch_size: int = 100,
verbose: bool = False,
scale: float = 1.,
):
"""Initializes the SPLADE encoder.
Args:
splade: :class:`pyt_splade.Splade` instance
text_field: the input text field to encode
out_field: the output field to store the encoded representation
rep: 'q' for query, 'd' for document
sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
batch_size: the batch size to use when encoding
verbose: if True, show a progress bar
scale: the scale to apply to the term frequencies
"""
self.splade = splade
self.text_field = text_field
self.out_field = out_field
self.rep = rep
self.sparse = sparse
self.batch_size = batch_size
self.verbose = verbose
self.scale = scale

def transform(self, df: pd.DataFrame) -> pd.DataFrame:
"""Encodes the text field in the input DataFrame."""
pta.validate.columns(df, includes=[self.text_field])
it = iter(df[self.text_field])
if self.verbose:
it = pt.tqdm(it, total=len(df), unit=self.text_field)
res = []
for batch in more_itertools.chunked(it, self.batch_size):
res.extend(self.splade.encode(batch, self.rep, format='dict' if self.sparse else 'np', scale=self.scale))
return df.assign(**{self.out_field: res})
135 changes: 135 additions & 0 deletions pyt_splade/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Union, List, Literal, Dict
import torch
import numpy as np
import pyterrier as pt
import pyt_splade

class Splade:
"""A SPLADE model, which provides transformers for sparse encoding documents and queries, and scoring documents."""

def __init__(
self,
model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
tokenizer=None,
agg='max',
max_length=256,
device=None
):
"""Initializes the SPLADE model.
Args:
model: the SPLADE model to use, either a PyTorch model or a string to load from HuggingFace
tokenizer: the tokenizer to use, if not included in the model
agg: the aggregation function to use for the SPLADE model
max_length: the maximum length of the input sequences
device: the device to use, e.g. 'cuda' or 'cpu'
"""
self.max_length = max_length
self.model = model
self.tokenizer = tokenizer
if device is None:
self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
else:
self.device = torch.device(device)
if isinstance(model, str):
from splade.models.transformer_rep import Splade
if self.tokenizer is None:
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(model)
self.model = Splade(model, agg=agg)
self.model.eval()
self.model = self.model.to(self.device)
else:
if self.tokenizer is None:
raise ValueError("you must specify tokenizer if passing a model")

self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
"""Returns a transformer that encodes a text field into a document representation.
Args:
text_field: the text field to encode
batch_size: the batch size to use when encoding
sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
verbose: if True, show a progress bar
scale: the scale to apply to the term frequencies
"""
out_field = 'toks' if sparse else 'doc_vec'
return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

indexing = doc_encoder # backward compatible name

def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
"""Returns a transformer that encodes a query field into a query representation.
Args:
batch_size: the batch size to use when encoding
sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
verbose: if True, show a progress bar
scale: the scale to apply to the term frequencies
"""
out_field = 'query_toks' if sparse else 'query_vec'
res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
return res

query = query_encoder # backward compatible name

def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
"""Returns a transformer that scores documents against queries.
Args:
text_field: the text field to score
batch_size: the batch size to use when scoring
verbose: if True, show a progress bar
"""
return pyt_splade.SpladeScorer(self, text_field, batch_size, verbose)

def encode(
self,
texts: List[str],
rep: Literal['d', 'q'] = 'd',
format: Literal['dict', 'np', 'torch'] ='dict',
scale: float = 1.,
) -> Union[List[Dict[str, float]], List[np.ndarray], torch.Tensor]:
"""Encodes a batch of texts into their SPLADE representations.
Args:
texts: the list of texts to encode
rep: 'q' for query, 'd' for document
format: 'dict' for a dict of term frequencies, 'np' for a list of numpy arrays, 'torch' for a torch tensor
scale: the scale to apply to the term frequencies
"""
rtr = []
with torch.no_grad():
reps = self.model(**{rep + '_kwargs': self.tokenizer(
texts,
add_special_tokens=True,
padding="longest", # pad to max sequence length in batch
truncation="longest_first", # truncates to max model length,
max_length=self.max_length,
return_attention_mask=True,
return_tensors="pt",
).to(self.device)})[rep + '_rep']
reps = reps * scale
if format == 'dict':
reps = reps.cpu()
for i in range(reps.shape[0]):
# get the number of non-zero dimensions in the rep:
col = torch.nonzero(reps[i]).squeeze(1).tolist()
# now let's create the bow representation as a dictionary
weights = reps[i, col].cpu().tolist()
# if document cast to int to make the weights ready for terrier indexing
if rep == "d":
weights = list(map(int, weights))
sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
# create the dict removing the weights less than 1, i.e. 0, that are not helpful
d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0}
rtr.append(d)
elif format == 'np':
reps = reps.cpu().numpy()
for i in range(reps.shape[0]):
rtr.append(reps[i])
elif format == 'torch':
rtr = reps
return rtr
Loading

0 comments on commit 773dbb6

Please sign in to comment.