Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactoring and integrated documentation #6

Merged
merged 6 commits into from
Sep 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
recursive-include pyt_splade *.rst
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ index_ref = indxr_pipe.index(dataset.get_corpus_iter(), batch_size=128)
# Retrieval

Similarly, SPLADE encodes the query into BERT WordPieces and corresponding weights.
We apply this as a query encoding transformer. It encodes the query into Terrier's matchop query language, to avoid tokenisation problems.
We apply this as a query encoding transformer.

```python

splade_retr = splade.query_encoder(matchop=True) >> pt.terrier.Retrieve('./msmarco_psg', wmodel='Tf')
splade_retr = splade.query_encoder() >> pt.terrier.Retriever('./msmarco_psg', wmodel='Tf')

```

Expand Down
29,190 changes: 382 additions & 28,808 deletions msmarco-psg-v1.ipynb

Large diffs are not rendered by default.

Binary file added pyt_splade/.DS_Store
Binary file not shown.
190 changes: 8 additions & 182 deletions pyt_splade/__init__.py
Original file line number Diff line number Diff line change
@@ -1,185 +1,11 @@
import base64
import string
import more_itertools
import pyterrier as pt
__version__ = '0.0.2'

assert pt.started()
from typing import Union
import torch
import numpy as np
import pandas as pd
from pyt_splade._model import Splade
from pyt_splade._encoder import SpladeEncoder
from pyt_splade._scorer import SpladeScorer
from pyt_splade._utils import Toks2Doc

SpladeFactory = Splade # backward compatible name
toks2doc = Toks2Doc # backward compatible name

class Splade():
    """Wraps a SPLADE model and tokenizer and builds PyTerrier transformers from them.

    The instance keeps the model, the tokenizer, the target device and a
    reverse vocabulary (token id -> token string) used to turn model outputs
    into ``{token: weight}`` dictionaries.
    """

    def __init__(
            self,
            model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
            tokenizer=None,
            agg='max',
            max_length=256,
            device=None):
        """Loads (or adopts) the model and tokenizer.

        Args:
            model: either a model name, which is loaded through
                ``splade.models.transformer_rep.Splade``, or an already
                constructed model object.
            tokenizer: the tokenizer to use. Required when ``model`` is an
                object; when ``model`` is a name and this is None, it is
                loaded with ``AutoTokenizer.from_pretrained(model)``.
            agg: aggregation argument forwarded to the SPLADE model constructor.
            max_length: maximum tokenized length passed to the tokenizer.
            device: device spec; when None, CUDA is used if available, else CPU.

        Raises:
            ValueError: if a model object is passed without a tokenizer.
        """
        self.max_length = max_length
        self.model = model
        self.tokenizer = tokenizer
        if device is None:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = torch.device(device)
        if isinstance(model, str):
            from splade.models.transformer_rep import Splade
            if self.tokenizer is None:
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(model)
            self.model = Splade(model, agg=agg)
        elif self.tokenizer is None:
            raise ValueError("you must specify tokenizer if passing a model")

        # encode() always moves the tokenized inputs to self.device, so the
        # model has to live there too -- including a model object handed in by
        # the caller, which previously stayed wherever it was created.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model = self.model.to(self.device)

        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes ``text_field`` as a document.

        The output column is ``'toks'`` when ``sparse`` is True and
        ``'doc_vec'`` otherwise.
        """
        out_field = 'toks' if sparse else 'doc_vec'
        return SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

    indexing = doc_encoder  # backward compatible name

    def query_encoder(self, batch_size=100, sparse=True, verbose=False, matchop=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes the ``'query'`` column.

        The output column is ``'query_toks'`` when ``sparse`` is True and
        ``'query_vec'`` otherwise. When ``matchop`` is True the encoder is
        followed by :class:`MatchOp`.
        """
        out_field = 'query_toks' if sparse else 'query_vec'
        res = SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
        if matchop:
            res = res >> MatchOp()
        return res

    def query(self, batch_size=100, sparse=True, verbose=False, matchop=True, scale=100) -> pt.Transformer:
        """Backward compatible name for :meth:`query_encoder`, with ``matchop=True`` by default."""
        return self.query_encoder(batch_size, sparse, verbose, matchop, scale)

    def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
        """Returns a transformer that scores ``text_field`` against ``'query'``."""
        return SpladeScorer(self, text_field, batch_size, verbose)

    def encode(self, texts, rep='d', format='dict', scale=1.):
        """Encodes a batch of texts with the model.

        Args:
            texts: the batch of texts handed to the tokenizer.
            rep: ``'d'`` or ``'q'``; selects the ``<rep>_kwargs`` model
                argument and the ``<rep>_rep`` entry of the model output.
            format: ``'dict'`` for a list of ``{token: weight}`` dicts,
                ``'np'`` for a list of numpy rows, ``'torch'`` for the tensor.
            scale: multiplier applied to the representations.

        Raises:
            ValueError: if ``format`` is not one of the supported values
                (previously an empty list was returned silently).
        """
        if format not in ('dict', 'np', 'torch'):
            raise ValueError(f"unsupported format {format!r}; expected 'dict', 'np' or 'torch'")
        with torch.no_grad():
            inputs = self.tokenizer(
                texts,
                add_special_tokens=True,
                padding="longest",  # pad to max sequence length in batch
                truncation="longest_first",  # truncates to max model length,
                max_length=self.max_length,
                return_attention_mask=True,
                return_tensors="pt",
            ).to(self.device)
            reps = self.model(**{rep + '_kwargs': inputs})[rep + '_rep']
        reps = reps * scale
        if format == 'torch':
            return reps
        reps = reps.cpu()
        if format == 'np':
            return list(reps.numpy())
        rtr = []
        for i in range(reps.shape[0]):
            # indices of the non-zero dimensions in the rep:
            col = torch.nonzero(reps[i]).squeeze(1).tolist()
            # now let's create the bow representation as a dictionary
            weights = reps[i, col].tolist()
            # if document cast to int to make the weights ready for terrier indexing
            if rep == "d":
                weights = list(map(int, weights))
            # heaviest weight first, ties broken by token id
            sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
            # drop non-positive weights (e.g. those truncated to 0 by the int cast)
            rtr.append({self.reverse_voc[k]: v for k, v in sorted_weights if v > 0})
        return rtr


SpladeFactory = Splade # backward compatible name


class SpladeEncoder(pt.Transformer):
    """Transformer that encodes one text column with a ``Splade`` instance.

    Each row's text is encoded in batches and the results are written to
    ``out_field``: dictionaries when ``sparse`` is True, numpy rows otherwise.
    """

    def __init__(self, splade, text_field, out_field, rep, sparse=True, batch_size=100, verbose=False, scale=1.):
        """Stores the encoder configuration.

        Args:
            splade: object providing ``encode(texts, rep, format=..., scale=...)``.
            text_field: name of the input column to encode.
            out_field: name of the output column.
            rep: representation selector forwarded to ``splade.encode``.
            sparse: True for ``'dict'`` output, False for ``'np'`` output.
            batch_size: number of texts per ``encode`` call.
            verbose: if True, wrap the iteration in ``pt.tqdm``.
            scale: scale forwarded to ``splade.encode``.
        """
        self.splade, self.rep = splade, rep
        self.text_field, self.out_field = text_field, out_field
        self.sparse, self.scale = sparse, scale
        self.batch_size, self.verbose = batch_size, verbose

    def transform(self, df):
        """Returns ``df`` with ``out_field`` holding the encoding of ``text_field``."""
        assert self.text_field in df.columns
        texts = iter(df[self.text_field])
        if self.verbose:
            texts = pt.tqdm(texts, total=len(df), unit=self.text_field)
        output_format = 'dict' if self.sparse else 'np'
        encoded = [
            enc
            for batch in more_itertools.chunked(texts, self.batch_size)
            for enc in self.splade.encode(batch, self.rep, format=output_format, scale=self.scale)
        ]
        return df.assign(**{self.out_field: encoded})


class SpladeScorer(pt.Transformer):
    """Transformer that scores each row's text against its query with SPLADE.

    Rows are grouped by the ``'query'`` column; each group's query is encoded
    once and multiplied against the encodings of the group's texts.
    """

    def __init__(self, splade, text_field, batch_size=100, verbose=False):
        """Stores the scorer configuration.

        Args:
            splade: object providing ``encode(texts, rep, format)``.
            text_field: name of the column holding the text to score.
            batch_size: number of texts encoded per ``encode`` call.
            verbose: if True, show a ``pt.tqdm`` progress bar over queries.
        """
        self.splade = splade
        self.text_field = text_field
        self.batch_size = batch_size
        self.verbose = verbose

    def transform(self, df):
        """Returns the input rows with ``score`` added, passed through ``add_ranks``.

        An input that yields no query groups (e.g. an empty frame) produces an
        empty scored frame instead of failing inside ``pd.concat``.
        """
        assert all(f in df.columns for f in ['query', self.text_field])
        from pyterrier.model import add_ranks
        groups = df.groupby('query')
        if self.verbose:
            groups = pt.tqdm(groups, unit='query')
        scored = []
        # query_df (not df) so the method argument stays available below
        for query, query_df in groups:
            query_enc = self.splade.encode([query], 'q', 'torch')
            scores = []
            for batch in more_itertools.chunked(query_df[self.text_field], self.batch_size):
                doc_enc = self.splade.encode(batch, 'd', 'torch')
                # product of the query encoding with each document encoding in the batch
                scores.append((query_enc @ doc_enc.T).flatten().cpu().numpy())
            scored.append(query_df.assign(score=np.concatenate(scores)))
        if scored:
            res = pd.concat(scored)
        else:
            # pd.concat refuses an empty list; keep the columns, add an empty score
            res = df.iloc[:0].assign(score=np.empty(0, dtype='float64'))
        return add_ranks(res)


class MatchOp(pt.Transformer):
    """Rewrites ``query_toks`` dictionaries into a single query string.

    Every ``(token, weight)`` pair is rendered with ``_matchop`` and the pieces
    are joined with spaces. The result replaces ``query`` and the
    ``query_toks`` column is dropped.
    """

    def transform(self, df):
        """Returns the frame from ``push_queries(df)`` with the rewritten ``query``."""
        assert 'query_toks' in df.columns
        from pyterrier.model import push_queries

        def _render(toks):
            # one matchop expression per weighted token, space separated
            return ' '.join(_matchop(term, weight) for term, weight in toks.items())

        pushed = push_queries(df)
        rewritten = pushed.assign(query=df.query_toks.apply(_render))
        return rewritten.drop(columns=['query_toks'])


def _matchop(t, w):
if not all(a in string.ascii_letters + string.digits for a in t):
encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8")
t = f'#base64({encoded})'
if w != 1:
t = f'#combine:0={w}({t})'
return t


def toks2doc(mult=100):
    """Returns a transformer that expands ``toks`` dictionaries into text.

    Each token is repeated ``int(mult * weight)`` times, every occurrence
    followed by a single space. The result is stored in ``text`` and the
    ``toks`` column is dropped.

    Args:
        mult: multiplier applied to each weight before truncating to an int.
    """
    def _dict_tf2text(tfdict):
        # str.join over repeated pieces avoids rebuilding the string once per occurrence
        return "".join((t + " ") * int(mult * tf) for t, tf in tfdict.items())

    def _rowtransform(df):
        # assign/drop return new frames, so the caller's frame is left untouched
        return df.assign(text=df['toks'].apply(_dict_tf2text)).drop(columns=['toks'])

    return pt.apply.generic(_rowtransform)
__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc']
53 changes: 53 additions & 0 deletions pyt_splade/_encoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Literal
import more_itertools
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
import pyt_splade


class SpladeEncoder(pt.Transformer):
    """Transformer that encodes a text column with a SPLADE model.

    Depending on ``sparse``, each row receives either a dictionary or a numpy
    row in the output column.
    """

    def __init__(
        self,
        splade: pyt_splade.Splade,
        text_field: str,
        out_field: str,
        rep: Literal['q', 'd'],
        sparse: bool = True,
        batch_size: int = 100,
        verbose: bool = False,
        scale: float = 1.,
    ):
        """Stores the encoder configuration.

        Args:
            splade: the :class:`pyt_splade.Splade` instance that does the encoding
            text_field: name of the input column to encode
            out_field: name of the column receiving the encoded representation
            rep: 'q' to encode as a query, 'd' to encode as a document
            sparse: True for dictionary output, False for numpy output
            batch_size: number of texts passed to each ``encode`` call
            verbose: if True, show a progress bar
            scale: scale forwarded to ``splade.encode``
        """
        self.splade, self.rep = splade, rep
        self.text_field, self.out_field = text_field, out_field
        self.sparse, self.scale = sparse, scale
        self.batch_size, self.verbose = batch_size, verbose

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Returns ``df`` with ``out_field`` holding the encoding of ``text_field``."""
        pta.validate.columns(df, includes=[self.text_field])
        texts = iter(df[self.text_field])
        if self.verbose:
            texts = pt.tqdm(texts, total=len(df), unit=self.text_field)
        output_format = 'dict' if self.sparse else 'np'
        encoded = [
            enc
            for batch in more_itertools.chunked(texts, self.batch_size)
            for enc in self.splade.encode(batch, self.rep, format=output_format, scale=self.scale)
        ]
        return df.assign(**{self.out_field: encoded})
135 changes: 135 additions & 0 deletions pyt_splade/_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Union, List, Literal, Dict
import torch
import numpy as np
import pyterrier as pt
import pyt_splade

class Splade:
    """A SPLADE model, which provides transformers for sparse encoding documents and queries, and scoring documents."""

    def __init__(
        self,
        model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
        tokenizer=None,
        agg='max',
        max_length=256,
        device=None
    ):
        """Initializes the SPLADE model.

        Args:
            model: the SPLADE model to use, either a PyTorch model or a string to load from HuggingFace
            tokenizer: the tokenizer to use, if not included in the model
            agg: the aggregation function to use for the SPLADE model
            max_length: the maximum length of the input sequences
            device: the device to use, e.g. 'cuda' or 'cpu'; when None, CUDA is used if available

        Raises:
            ValueError: if a model object is passed without a tokenizer
        """
        self.max_length = max_length
        self.model = model
        self.tokenizer = tokenizer
        if device is None:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = torch.device(device)
        if isinstance(model, str):
            from splade.models.transformer_rep import Splade
            if self.tokenizer is None:
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(model)
            self.model = Splade(model, agg=agg)
        elif self.tokenizer is None:
            raise ValueError("you must specify tokenizer if passing a model")

        # encode() always moves the tokenized inputs to self.device, so the
        # model has to live there too -- including a model object handed in by
        # the caller, which previously stayed wherever it was created.
        if isinstance(self.model, torch.nn.Module):
            self.model.eval()
            self.model = self.model.to(self.device)

        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes a text field into a document representation.

        Args:
            text_field: the text field to encode
            batch_size: the batch size to use when encoding
            sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
            verbose: if True, show a progress bar
            scale: the scale to apply to the term frequencies
        """
        out_field = 'toks' if sparse else 'doc_vec'
        return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

    indexing = doc_encoder  # backward compatible name

    def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes a query field into a query representation.

        Args:
            batch_size: the batch size to use when encoding
            sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
            verbose: if True, show a progress bar
            scale: the scale to apply to the term frequencies
        """
        out_field = 'query_toks' if sparse else 'query_vec'
        res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
        return res

    query = query_encoder  # backward compatible name

    def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
        """Returns a transformer that scores documents against queries.

        Args:
            text_field: the text field to score
            batch_size: the batch size to use when scoring
            verbose: if True, show a progress bar
        """
        return pyt_splade.SpladeScorer(self, text_field, batch_size, verbose)

    def encode(
        self,
        texts: List[str],
        rep: Literal['d', 'q'] = 'd',
        format: Literal['dict', 'np', 'torch'] = 'dict',
        scale: float = 1.,
    ) -> Union[List[Dict[str, float]], List[np.ndarray], torch.Tensor]:
        """Encodes a batch of texts into their SPLADE representations.

        Args:
            texts: the list of texts to encode
            rep: 'q' for query, 'd' for document
            format: 'dict' for a dict of term frequencies, 'np' for a list of numpy arrays, 'torch' for a torch tensor
            scale: the scale to apply to the term frequencies

        Raises:
            ValueError: if ``format`` is not one of the supported values
                (previously an empty list was returned silently)
        """
        if format not in ('dict', 'np', 'torch'):
            raise ValueError(f"unsupported format {format!r}; expected 'dict', 'np' or 'torch'")
        with torch.no_grad():
            inputs = self.tokenizer(
                texts,
                add_special_tokens=True,
                padding="longest",  # pad to max sequence length in batch
                truncation="longest_first",  # truncates to max model length,
                max_length=self.max_length,
                return_attention_mask=True,
                return_tensors="pt",
            ).to(self.device)
            reps = self.model(**{rep + '_kwargs': inputs})[rep + '_rep']
        reps = reps * scale
        if format == 'torch':
            return reps
        reps = reps.cpu()
        if format == 'np':
            return list(reps.numpy())
        rtr = []
        for i in range(reps.shape[0]):
            # indices of the non-zero dimensions in the rep:
            col = torch.nonzero(reps[i]).squeeze(1).tolist()
            # now let's create the bow representation as a dictionary
            weights = reps[i, col].tolist()
            # if document cast to int to make the weights ready for terrier indexing
            if rep == "d":
                weights = list(map(int, weights))
            # heaviest weight first, ties broken by token id
            sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
            # drop non-positive weights (e.g. those truncated to 0 by the int cast)
            rtr.append({self.reverse_voc[k]: v for k, v in sorted_weights if v > 0})
        return rtr
Loading