Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Colbert PRF as a textual reranker #62

Open
cmacdonald opened this issue Mar 13, 2023 · 1 comment
Open

Colbert PRF as a textual reranker #62

cmacdonald opened this issue Mar 13, 2023 · 1 comment

Comments

@cmacdonald
Copy link
Collaborator

Maik Fröbe requested ColBERT PRF as a textual reranker.

I think the code should look like this:

# Proposed pipeline: BM25 retrieval (with docno+text metadata) -> ColBERT text
# encoding -> ColBERT-PRF query expansion -> ColBERT re-scoring.
# NOTE(review): `checkpoint` and `sparse_index` are assumed to be defined by the
# caller; ColbertPRF here would need to obtain token-level IDF without a dense
# index (see the linked ranking.py lines below).
colbert = ColBERTModelOnlyFactory(checkpoint)
bm25 = pt.BatchRetrieve(sparse_index, wmodel='BM25', metadata=['docno', 'text'])
cprf_reranker = (
    bm25 
    >> colbert.text_encoder() 
    >> ColbertPRF(colbert, k=64, fb_embs=10, beta=1, fb_docs=10, return_docs=True) 
    >> colbert.scorer()
)

but: The only thing the index is used for is the token-level IDF, so we'd need to work around that...
https://github.com/terrierteam/pyterrier_colbert/blob/main/pyterrier_colbert/ranking.py#L1020-L1024

Cc/ @seanmacavaney

@cmacdonald cmacdonald changed the title Colbert PRF as a reranker Colbert PRF as a textual reranker Mar 13, 2023
@Xiao0728
Copy link
Collaborator

I have used the pipeline BM25 >> ColBERT-PRF, implemented as follows:

from sklearn.cluster import KMeans
from pyterrier.transformer import TransformerBase
import pandas as pd
class ColBERTPRF_docencoded(TransformerBase):
    """Pseudo-relevance feedback over pre-computed ColBERT document embeddings.

    Clusters the embeddings of the top-ranked (feedback) documents with KMeans,
    maps each centroid back to its most likely BERT token, keeps the
    ``exp_terms`` highest-weighted centroids and appends them (with weights) to
    the query embeddings, producing an expanded query for a downstream ColBERT
    scorer.

    NOTE(review): this relies on module-level globals that must be defined by
    the surrounding notebook/script: ``torch``, ``np``, ``fnt``,
    ``get_nearest_tokens_for_emb``, and the token-weight lookups
    ``id2meancos`` / ``idfdict`` / ``ictfdict`` / ``probIDF``.
    """

    def __init__(self, k, exp_terms, beta=1, r=42, mean_cos_weight=False,
                 idf_weight=False, ictf_weight=False, probIDF_weight=False,
                 return_docs=False, fb_docs=10, *args, **kwargs):
        """
        Args:
            k: number of KMeans clusters built from the feedback embeddings.
            exp_terms: number of expansion embeddings appended (must be < k).
            beta: weight applied to each expansion embedding.
            r: random_state for KMeans (reproducibility).
            mean_cos_weight, idf_weight, ictf_weight, probIDF_weight:
                alternative schemes for scoring candidate expansion tokens;
                if none is set, no candidates are collected and the query is
                returned unexpanded.
            return_docs: if True, re-attach the per-document rows (and their
                doc_embs) to the expanded-query frame.
            fb_docs: number of top-ranked feedback documents used.

        Raises:
            ValueError: if ``exp_terms`` is not smaller than ``k``.
        """
        super().__init__(*args, **kwargs)
        self.k = k
        self.exp_terms = exp_terms
        self.beta = beta
        self.mean_cos_weight = mean_cos_weight
        self.idf_weight = idf_weight
        self.probIDF_weight = probIDF_weight
        self.return_docs = return_docs
        self.fb_docs = fb_docs
        self.r = r
        self.ictf_weight = ictf_weight
        # raise instead of assert: input validation must survive `python -O`
        if self.k <= self.exp_terms:
            raise ValueError("exp_terms should be smaller than number of clusters")

    def _get_prf_embs(self, df, num_docs):
        # Concatenate the per-document embedding tensors of the top num_docs rows.
        return torch.cat(df.head(num_docs).doc_embs.values.tolist(), dim=0)

    def _score_centroids(self, centroids):
        """Map each cluster centroid to (embedding, token, tid, weight).

        Centroids whose nearest-token lookup is empty, or for which no
        weighting scheme is enabled, are skipped.
        """
        emb_and_score = []
        for cluster_centroid in centroids:
            # KMeans centroids come back float64; ColBERT embeddings are float32.
            centroid = np.float32(cluster_centroid)
            tok2freq = get_nearest_tokens_for_emb(fnt, centroid)
            if not tok2freq:
                continue
            most_likely_tok = max(tok2freq, key=tok2freq.get)
            tid = fnt.inference.query_tokenizer.tok.convert_tokens_to_ids(most_likely_tok)
            if self.mean_cos_weight:
                weight = id2meancos[tid]   # mean-cosine score, unnormalised
            elif self.idf_weight:
                weight = idfdict[tid]      # IDF score, unnormalised
            elif self.ictf_weight:
                weight = ictfdict[tid]     # ICTF score, unnormalised
            elif self.probIDF_weight:
                weight = probIDF[tid]      # probIDF score, unnormalised
            else:
                continue
            emb_and_score.append((centroid, most_likely_tok, tid, weight))
        return emb_and_score

    def transform_query(self, topic_and_res):
        """Build the single expanded-query row for one query's result set."""
        topic_and_res = topic_and_res.sort_values('rank')
        prf_embs = self._get_prf_embs(topic_and_res, self.fb_docs)

        kmn = KMeans(self.k, random_state=self.r)
        kmn.fit(prf_embs)

        emb_and_score = self._score_centroids(kmn.cluster_centers_)
        # Keep the exp_terms highest-weighted candidates.
        emb_and_score.sort(key=lambda tup: -tup[3])
        toks, scores, exp_embds = [], [], []
        for emb, tok, _tid, score in emb_and_score[:self.exp_terms]:
            toks.append(tok)
            scores.append(score)
            exp_embds.append(emb)

        first_row = topic_and_res.iloc[0]
        # Bug fix: guard the empty-expansion case — torch.cat() with an empty
        # 1-d tensor against the 2-d query embeddings raises a dim mismatch.
        if exp_embds:
            newemb = torch.cat([first_row.query_embs, torch.Tensor(exp_embds)])
            # Bug fix: torch.full takes a size *tuple*, and the number of
            # expansion embeddings may be fewer than exp_terms, so size the
            # weights on what was actually appended.
            if scores:
                exp_weights = self.beta * torch.Tensor(scores)
            else:
                exp_weights = torch.full((len(exp_embds),), float(self.beta))
            weights = torch.cat([torch.ones(len(first_row.query_embs)), exp_weights])
        else:
            newemb = first_row.query_embs
            weights = torch.ones(len(first_row.query_embs))

        return pd.DataFrame(
            [[first_row.qid, first_row.docno, first_row.query, newemb, toks, weights]],
            columns=["qid", "docno", "query", "query_embs", "query_toks", "query_weights"])

    def transform(self, topics_and_docs):
        """Apply PRF expansion per query; optionally re-attach document rows."""
        # Validate the input. Bug fix: "docid" is no longer required up front —
        # the original both asserted its presence and contained a (dead,
        # self-referential) branch to restore it; it is derived from docno when
        # absent, matching the intent of the original commented-out line.
        required = ["qid", "query", "docno", "query_embs"]
        missing = [col for col in required if col not in topics_and_docs.columns]
        if missing:
            raise ValueError(f"input missing required columns: {missing}")
        if "docid" not in topics_and_docs.columns:
            # assumes docno holds the integer docid — TODO confirm against caller
            topics_and_docs["docid"] = topics_and_docs.docno.astype("int").values
        rtr = []
        for qid, res in topics_and_docs.groupby("qid"):
            new_query_df = self.transform_query(res)
            if self.return_docs:
                # Re-attach per-document rows so a scorer can follow this stage.
                new_query_df = res[["qid", "docno", "docid", "doc_embs"]].merge(new_query_df, on=["qid"])
                new_query_df = new_query_df.rename(columns={'docno_x': 'docno'})
            rtr.append(new_query_df)
        return pd.concat(rtr)

The experiment is run as follows:

# End-to-end passage pipeline: encode the query, retrieve with BM25, split
# documents into sliding-window passages, encode passages, then ColBERT-score.
pipeE2E_psg = pytcolbert.query_encoder() >> BM25 >> pt.text.sliding(prepend_title=False ) >> doc_encoder(pytcolbert ) >> scorer(pytcolbert)
# Expand the query with ColBERT-PRF (IDF-weighted, 3 feedback docs, 24 clusters,
# 10 expansion terms), keeping the document embeddings for re-scoring.
pipePRF_rerank = pipeE2E_psg >> ColBERTPRF_docencoded(k=24, exp_terms=10, idf_weight=True, beta=1,fb_docs=3,return_docs=True)
# Re-score with the expanded query, take max-passage per document, cut to 1000.
bm25_prf_rerank = (pipePRF_rerank >> scorer(pytcolbert)>>pt.text.max_passage())%1000

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants