refactoring and integrated documentation (#6)
1 parent 53489b2, commit 773dbb6
Showing 15 changed files with 801 additions and 29,037 deletions.
@@ -0,0 +1 @@
recursive-include pyt_splade *.rst
@@ -1,185 +1,11 @@
import base64
import string
import more_itertools
import pyterrier as pt
__version__ = '0.0.2'

assert pt.started()
from typing import Union
import torch
import numpy as np
import pandas as pd
from pyt_splade._model import Splade
from pyt_splade._encoder import SpladeEncoder
from pyt_splade._scorer import SpladeScorer
from pyt_splade._utils import Toks2Doc

SpladeFactory = Splade  # backward compatible name
toks2doc = Toks2Doc  # backward compatible name


class Splade():

    def __init__(
            self,
            model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
            tokenizer=None,
            agg='max',
            max_length=256,
            device=None):
        self.max_length = max_length
        self.model = model
        self.tokenizer = tokenizer
        if device is None:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = torch.device(device)
        if isinstance(model, str):
            from splade.models.transformer_rep import Splade
            if self.tokenizer is None:
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(model)
            self.model = Splade(model, agg=agg)
            self.model.eval()
            self.model = self.model.to(self.device)
        else:
            if self.tokenizer is None:
                raise ValueError("you must specify tokenizer if passing a model")

        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        out_field = 'toks' if sparse else 'doc_vec'
        return SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

    indexing = doc_encoder  # backward compatible name

    def query_encoder(self, batch_size=100, sparse=True, verbose=False, matchop=False, scale=100) -> pt.Transformer:
        out_field = 'query_toks' if sparse else 'query_vec'
        res = SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
        if matchop:
            res = res >> MatchOp()
        return res

    def query(self, batch_size=100, sparse=True, verbose=False, matchop=True, scale=100) -> pt.Transformer:
        # backward compatible name w/ default matchop=True
        return self.query_encoder(batch_size, sparse, verbose, matchop, scale)

    def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
        return SpladeScorer(self, text_field, batch_size, verbose)

    def encode(self, texts, rep='d', format='dict', scale=1.):
        rtr = []
        with torch.no_grad():
            reps = self.model(**{rep + '_kwargs': self.tokenizer(
                texts,
                add_special_tokens=True,
                padding="longest",  # pad to max sequence length in batch
                truncation="longest_first",  # truncates to max model length
                max_length=self.max_length,
                return_attention_mask=True,
                return_tensors="pt",
            ).to(self.device)})[rep + '_rep']
        reps = reps * scale
        if format == 'dict':
            reps = reps.cpu()
            for i in range(reps.shape[0]):
                # get the non-zero dimensions in the rep:
                col = torch.nonzero(reps[i]).squeeze(1).tolist()
                # now let's create the bow representation as a dictionary
                weights = reps[i, col].cpu().tolist()
                # if document, cast to int to make the weights ready for terrier indexing
                if rep == "d":
                    weights = list(map(int, weights))
                sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
                # create the dict, removing weights less than 1 (i.e. 0), which are not helpful
                d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0}
                rtr.append(d)
        elif format == 'np':
            reps = reps.cpu().numpy()
            for i in range(reps.shape[0]):
                rtr.append(reps[i])
        elif format == 'torch':
            rtr = reps
        return rtr


SpladeFactory = Splade  # backward compatible name


class SpladeEncoder(pt.Transformer):
    def __init__(self, splade, text_field, out_field, rep, sparse=True, batch_size=100, verbose=False, scale=1.):
        self.splade = splade
        self.text_field = text_field
        self.out_field = out_field
        self.rep = rep
        self.sparse = sparse
        self.batch_size = batch_size
        self.verbose = verbose
        self.scale = scale

    def transform(self, df):
        assert self.text_field in df.columns
        it = iter(df[self.text_field])
        if self.verbose:
            it = pt.tqdm(it, total=len(df), unit=self.text_field)
        res = []
        for batch in more_itertools.chunked(it, self.batch_size):
            res.extend(self.splade.encode(batch, self.rep, format='dict' if self.sparse else 'np', scale=self.scale))
        return df.assign(**{self.out_field: res})


class SpladeScorer(pt.Transformer):
    def __init__(self, splade, text_field, batch_size=100, verbose=False):
        self.splade = splade
        self.text_field = text_field
        self.batch_size = batch_size
        self.verbose = verbose

    def transform(self, df):
        assert all(f in df.columns for f in ['query', self.text_field])
        it = df.groupby('query')
        if self.verbose:
            it = pt.tqdm(it, unit='query')
        res = []
        for query, df in it:
            query_enc = self.splade.encode([query], 'q', 'torch')
            scores = []
            for batch in more_itertools.chunked(df[self.text_field], self.batch_size):
                doc_enc = self.splade.encode(batch, 'd', 'torch')
                scores.append((query_enc @ doc_enc.T).flatten().cpu().numpy())
            res.append(df.assign(score=np.concatenate(scores)))
        res = pd.concat(res)
        from pyterrier.model import add_ranks
        res = add_ranks(res)
        return res


class MatchOp(pt.Transformer):

    def transform(self, df):
        assert 'query_toks' in df.columns
        from pyterrier.model import push_queries
        rtr = push_queries(df)
        rtr = rtr.assign(
            query=df.query_toks.apply(lambda toks: ' '.join(_matchop(k, v) for k, v in toks.items())))
        rtr = rtr.drop(columns=['query_toks'])
        return rtr


def _matchop(t, w):
    if not all(a in string.ascii_letters + string.digits for a in t):
        encoded = base64.b64encode(t.encode('utf-8')).decode("utf-8")
        t = f'#base64({encoded})'
    if w != 1:
        t = f'#combine:0={w}({t})'
    return t


def toks2doc(mult=100):
    def _dict_tf2text(tfdict):
        rtr = ""
        for t in tfdict:
            for i in range(int(mult * tfdict[t])):
                rtr += t + " "
        return rtr

    def _rowtransform(df):
        df = df.copy()
        df["text"] = df['toks'].apply(_dict_tf2text)
        df.drop(columns=['toks'], inplace=True)
        return df

    return pt.apply.generic(_rowtransform)
__all__ = ['Splade', 'SpladeEncoder', 'SpladeScorer', 'SpladeFactory', 'Toks2Doc', 'toks2doc']
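For orientation (not part of the commit), a minimal sketch of how the refactored, backward-compatible surface is driven from PyTerrier. The toy DataFrames, the docno/qid values, and the pt.init() call are illustrative assumptions rather than code from this repository:

import pandas as pd
import pyterrier as pt
if not pt.started():
    pt.init()  # assumed: older PyTerrier initialisation, matching the pt.started() check above
import pyt_splade

splade = pyt_splade.Splade()  # loads naver/splade-cocondenser-ensembledistil by default

# Document side: doc_encoder() adds a 'toks' column of {token: int weight} dicts.
docs = pd.DataFrame([{'docno': 'd1', 'text': 'The quick brown fox jumps over the lazy dog.'}])
print(splade.doc_encoder().transform(docs)['toks'].iloc[0])

# Query side: query_encoder() adds a 'query_toks' column; SpladeFactory remains an alias of Splade.
queries = pd.DataFrame([{'qid': 'q1', 'query': 'jumping fox'}])
print(splade.query_encoder().transform(queries)['query_toks'].iloc[0])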
@@ -0,0 +1,53 @@
from typing import Literal
import more_itertools
import pandas as pd
import pyterrier as pt
import pyterrier_alpha as pta
import pyt_splade


class SpladeEncoder(pt.Transformer):
    """Encodes a text field using a SPLADE model. The output is a dense or sparse representation of the text field."""

    def __init__(
        self,
        splade: pyt_splade.Splade,
        text_field: str,
        out_field: str,
        rep: Literal['q', 'd'],
        sparse: bool = True,
        batch_size: int = 100,
        verbose: bool = False,
        scale: float = 1.,
    ):
        """Initializes the SPLADE encoder.
        Args:
            splade: :class:`pyt_splade.Splade` instance
            text_field: the input text field to encode
            out_field: the output field to store the encoded representation
            rep: 'q' for query, 'd' for document
            sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
            batch_size: the batch size to use when encoding
            verbose: if True, show a progress bar
            scale: the scale to apply to the term frequencies
        """
        self.splade = splade
        self.text_field = text_field
        self.out_field = out_field
        self.rep = rep
        self.sparse = sparse
        self.batch_size = batch_size
        self.verbose = verbose
        self.scale = scale

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Encodes the text field in the input DataFrame."""
        pta.validate.columns(df, includes=[self.text_field])
        it = iter(df[self.text_field])
        if self.verbose:
            it = pt.tqdm(it, total=len(df), unit=self.text_field)
        res = []
        for batch in more_itertools.chunked(it, self.batch_size):
            res.extend(self.splade.encode(batch, self.rep, format='dict' if self.sparse else 'np', scale=self.scale))
        return df.assign(**{self.out_field: res})
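A short, hypothetical illustration of how this encoder is usually obtained (via the Splade factory methods rather than constructed directly), including the dense-output mode; the example DataFrame is made up:

import pandas as pd
import pyt_splade

splade = pyt_splade.Splade()

# sparse=False switches the output field from 'toks' (dict) to 'doc_vec' (dense numpy vector)
dense_encoder = splade.doc_encoder(sparse=False, batch_size=8)
df = pd.DataFrame([{'docno': 'd1', 'text': 'splade expands documents with weighted terms'}])
print(dense_encoder.transform(df)['doc_vec'].iloc[0].shape)  # one weight per vocabulary entry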
@@ -0,0 +1,135 @@
from typing import Union, List, Literal, Dict
import torch
import numpy as np
import pyterrier as pt
import pyt_splade

class Splade:
    """A SPLADE model, which provides transformers for sparse encoding documents and queries, and scoring documents."""

    def __init__(
        self,
        model: Union[torch.nn.Module, str] = "naver/splade-cocondenser-ensembledistil",
        tokenizer=None,
        agg='max',
        max_length=256,
        device=None
    ):
        """Initializes the SPLADE model.
        Args:
            model: the SPLADE model to use, either a PyTorch model or a string to load from HuggingFace
            tokenizer: the tokenizer to use, if not included in the model
            agg: the aggregation function to use for the SPLADE model
            max_length: the maximum length of the input sequences
            device: the device to use, e.g. 'cuda' or 'cpu'
        """
        self.max_length = max_length
        self.model = model
        self.tokenizer = tokenizer
        if device is None:
            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        else:
            self.device = torch.device(device)
        if isinstance(model, str):
            from splade.models.transformer_rep import Splade
            if self.tokenizer is None:
                from transformers import AutoTokenizer
                self.tokenizer = AutoTokenizer.from_pretrained(model)
            self.model = Splade(model, agg=agg)
            self.model.eval()
            self.model = self.model.to(self.device)
        else:
            if self.tokenizer is None:
                raise ValueError("you must specify tokenizer if passing a model")

        self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

    def doc_encoder(self, text_field='text', batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes a text field into a document representation.
        Args:
            text_field: the text field to encode
            batch_size: the batch size to use when encoding
            sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
            verbose: if True, show a progress bar
            scale: the scale to apply to the term frequencies
        """
        out_field = 'toks' if sparse else 'doc_vec'
        return pyt_splade.SpladeEncoder(self, text_field, out_field, 'd', sparse, batch_size, verbose, scale)

    indexing = doc_encoder  # backward compatible name

    def query_encoder(self, batch_size=100, sparse=True, verbose=False, scale=100) -> pt.Transformer:
        """Returns a transformer that encodes a query field into a query representation.
        Args:
            batch_size: the batch size to use when encoding
            sparse: if True, the output will be a dict of term frequencies, otherwise a dense vector
            verbose: if True, show a progress bar
            scale: the scale to apply to the term frequencies
        """
        out_field = 'query_toks' if sparse else 'query_vec'
        res = pyt_splade.SpladeEncoder(self, 'query', out_field, 'q', sparse, batch_size, verbose, scale)
        return res

    query = query_encoder  # backward compatible name

    def scorer(self, text_field='text', batch_size=100, verbose=False) -> pt.Transformer:
        """Returns a transformer that scores documents against queries.
        Args:
            text_field: the text field to score
            batch_size: the batch size to use when scoring
            verbose: if True, show a progress bar
        """
        return pyt_splade.SpladeScorer(self, text_field, batch_size, verbose)

    def encode(
        self,
        texts: List[str],
        rep: Literal['d', 'q'] = 'd',
        format: Literal['dict', 'np', 'torch'] = 'dict',
        scale: float = 1.,
    ) -> Union[List[Dict[str, float]], List[np.ndarray], torch.Tensor]:
        """Encodes a batch of texts into their SPLADE representations.
        Args:
            texts: the list of texts to encode
            rep: 'q' for query, 'd' for document
            format: 'dict' for a dict of term frequencies, 'np' for a list of numpy arrays, 'torch' for a torch tensor
            scale: the scale to apply to the term frequencies
        """
        rtr = []
        with torch.no_grad():
            reps = self.model(**{rep + '_kwargs': self.tokenizer(
                texts,
                add_special_tokens=True,
                padding="longest",  # pad to max sequence length in batch
                truncation="longest_first",  # truncates to max model length
                max_length=self.max_length,
                return_attention_mask=True,
                return_tensors="pt",
            ).to(self.device)})[rep + '_rep']
        reps = reps * scale
        if format == 'dict':
            reps = reps.cpu()
            for i in range(reps.shape[0]):
                # get the non-zero dimensions in the rep:
                col = torch.nonzero(reps[i]).squeeze(1).tolist()
                # now let's create the bow representation as a dictionary
                weights = reps[i, col].cpu().tolist()
                # if document, cast to int to make the weights ready for terrier indexing
                if rep == "d":
                    weights = list(map(int, weights))
                sorted_weights = sorted(zip(col, weights), key=lambda x: (-x[1], x[0]))
                # create the dict, removing weights less than 1 (i.e. 0), which are not helpful
                d = {self.reverse_voc[k]: v for k, v in sorted_weights if v > 0}
                rtr.append(d)
        elif format == 'np':
            reps = reps.cpu().numpy()
            for i in range(reps.shape[0]):
                rtr.append(reps[i])
        elif format == 'torch':
            rtr = reps
        return rtr
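To make the encode() contract concrete, a hedged sketch (not from this commit) of calling it directly with each output format; the input text is illustrative:

import pyt_splade

splade = pyt_splade.Splade()
texts = ['splade produces sparse lexical representations']

# 'dict': one {term: weight} bag per text; for rep='d' weights are cast to int,
# so pass a scale (the doc_encoder default is 100) to avoid zeroing small weights
bows = splade.encode(texts, rep='d', format='dict', scale=100)

# 'np': a list of dense numpy vectors over the model vocabulary
vecs = splade.encode(texts, rep='q', format='np')

# 'torch': a single (batch x vocab) tensor, left on the model's device
tensor = splade.encode(texts, rep='q', format='torch')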