Skip to content

Commit

Permalink
splade
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmacavaney committed Aug 30, 2024
1 parent c99d18f commit c483c90
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 17 deletions.
28 changes: 13 additions & 15 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,48 @@
class TestBasic(unittest.TestCase):

def setUp(self):
- self.factory = pyt_splade.Splade(device='cpu')
+ self.splade = pyt_splade.Splade(device='cpu')

def test_transformer_indexing(self):
- df = (self.factory.indexing() >> pyt_splade.toks2doc()).transform_iter([{'docno' : 'd1', 'text' : 'hello there'}])
+ df = (self.splade.doc_encoder() >> pyt_splade.toks2doc())([{'docno' : 'd1', 'text' : 'hello there'}])
self.assertTrue('there there' in df.iloc[0].text)
- df = self.factory.indexing().transform_iter([
+ df = self.splade.doc_encoder()([
{'docno' : 'd1', 'text' : 'hello there'},
{'docno' : 'd1', 'text' : ''}, #empty
{'docno' : 'd1', 'text' : 'hello hello hello hello hello there'}])

def test_transformer_querying(self):
- q = self.factory.query()
+ q = self.splade.query_encoder()
df = q.transform_iter([{'qid' : 'q1', 'query' : 'chemical reactions'}])
print(df.iloc[0].query)
- self.assertTrue('#combine' in df.iloc[0].query)
+ self.assertTrue('query_toks' in df.columns)

def test_transformer_empty_query(self):
- q = self.factory.query()
+ q = self.splade.query_encoder()
res = q(pd.DataFrame([], columns=['qid', 'query']))
- self.assertEqual(['qid', 'query_0', 'query'], list(res.columns))
+ self.assertEqual(['qid', 'query', 'query_toks'], list(res.columns))

def test_transformer_empty_doc(self):
- d = self.factory.indexing()
+ d = self.splade.doc_encoder()
res = d(pd.DataFrame([], columns=['docno', 'text']))
self.assertEqual(['docno', 'text', 'toks'], list(res.columns))

def test_model_output_one_dim_non_zero_rep(self):
import torch
- one_dim_non_zero = torch.zeros(1, self.factory.model.output_dim)
+ one_dim_non_zero = torch.zeros(1, self.splade.model.output_dim)
one_dim_non_zero[0][0] = 1.
mock_return = {
"d_rep": one_dim_non_zero,
"q_rep": one_dim_non_zero,
}
- factory = pyt_splade.SpladeFactory(device='cpu')
- mock_model = MagicMock(return_value=mock_return)
- factory.model = mock_model
+ splade = pyt_splade.Splade(MagicMock(return_value=mock_return), tokenizer=self.splade.tokenizer, device='cpu')

- res = factory.indexing()(
+ res = splade.doc_encoder()(
[{'docno' : 'd1', 'text' : 'hello there'}]
)
self.assertEqual(['docno', 'text', 'toks'], list(res.columns))

- res = factory.query()(
+ res = splade.query_encoder()(
[{'qid' : 'd1', 'query' : 'chemical reactions'}]
)
- self.assertEqual(['qid', 'query_0', 'query'], list(res.columns))
+ self.assertEqual(['qid', 'query', 'query_toks'], list(res.columns))
4 changes: 2 additions & 2 deletions tests/test_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
class TestScorer(unittest.TestCase):

def setUp(self):
- self.factory = pyt_splade.Splade(device='cpu')
+ self.splade = pyt_splade.Splade(device='cpu')

def test_scorer(self):
- df = self.factory.scorer()([
+ df = self.splade.scorer()([
{'qid': '0', 'query': 'chemical reactions', 'docno' : 'd1', 'text' : 'hello there'},
{'qid': '0', 'query': 'chemical reactions', 'docno' : 'd2', 'text' : 'chemistry society'},
{'qid': '1', 'query': 'hello', 'docno' : 'd1', 'text' : 'hello there'},
Expand Down

0 comments on commit c483c90

Please sign in to comment.