Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include Vespa Lexical Search as an option to BEIR benchmark #76

Open
wants to merge 32 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
3d45c01
include pyvespa dependency
thigm85 Jan 24, 2022
27b3cfd
include pycharm config files on gitignore
thigm85 Jan 24, 2022
5674034
create a class for vespa lexical search
thigm85 Jan 24, 2022
064c34e
basic tests for vespa lexical search
thigm85 Jan 24, 2022
d764954
include vespa lexical search as a possible argument in the evaluator
thigm85 Jan 24, 2022
eac2910
add vespa benchmark script
thigm85 Jan 24, 2022
0c54ada
increase timeout
thigm85 Jan 24, 2022
4497ece
Merge branch 'tgm/vespa_lexical_search' into tgm/add-vespa-lexical-be…
thigm85 Jan 24, 2022
8197a73
include deployment parameters. Fix result object.
thigm85 Jan 24, 2022
14ab0c5
include remove app method
thigm85 Jan 24, 2022
caf5131
use remove app method on the unit tests
thigm85 Jan 24, 2022
23ca3f8
Merge branch 'tgm/vespa_lexical_search' into tgm/add-vespa-lexical-be…
thigm85 Jan 24, 2022
2ba1211
use remove method
thigm85 Jan 24, 2022
8d94af4
get container by name
thigm85 Jan 24, 2022
0a77411
Merge branch 'tgm/vespa_lexical_search' into tgm/add-vespa-lexical-be…
thigm85 Jan 24, 2022
ed06a11
add progress info
thigm85 Jan 25, 2022
4faf388
include all datasets
thigm85 Jan 25, 2022
7ae1a29
Include tenacity
thigm85 Jan 30, 2022
5be07ea
add retry strategy
thigm85 Jan 30, 2022
34cd6c5
update script
thigm85 Jan 31, 2022
88f0002
Merge pull request #3 from thigm85/tgm/vespa-lexical-experiment
thigm85 Feb 15, 2022
066ef2d
add issue link
thigm85 Feb 15, 2022
e5b24b6
process queries in parallel
thigm85 Feb 16, 2022
4c5ac5c
pre-process queries
thigm85 Feb 18, 2022
48a7c14
improve feeding and output information
thigm85 Feb 18, 2022
ac640ee
Merge pull request #4 from thigm85/tgm/vespa-lexical-experiment
lesters Feb 21, 2022
1061122
exclude cases where queries is returned as hits
thigm85 Feb 24, 2022
80f4f3a
add option to not exclude the dataset and not remove the app
thigm85 Mar 1, 2022
7ee7962
document benchmark script
thigm85 Mar 1, 2022
713887c
add split_type and fix deployment parameters
thigm85 Mar 3, 2022
b08a90e
expose timeout and async connections. Continue in case of empty hits.
thigm85 Mar 9, 2022
cc51233
msmarco uses dev set
thigm85 Mar 9, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,6 @@ dmypy.json

# Pyre type checker
.pyre/

# PyCharm
.idea
3 changes: 2 additions & 1 deletion beir/retrieval/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@
from .search.dense import DenseRetrievalExactSearch as DRES
from .search.dense import DenseRetrievalFaissSearch as DRFS
from .search.lexical import BM25Search as BM25
from .search.lexical.vespa_search import VespaLexicalSearch
from .search.sparse import SparseSearch as SS
from .custom_metrics import mrr, recall_cap, hole, top_k_accuracy

logger = logging.getLogger(__name__)

class EvaluateRetrieval:

def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS]] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS], VespaLexicalSearch] = None, k_values: List[int] = [1, 3, 5, 10, 100, 1000], score_function: str = "cos_sim"):
self.k_values = k_values
self.top_k = max(k_values)
self.retriever = retriever
Expand Down
303 changes: 303 additions & 0 deletions beir/retrieval/search/lexical/vespa_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
import shutil
import re
from statistics import mean, median
from collections import Counter
from typing import Dict, Optional
from vespa.application import Vespa
from vespa.package import ApplicationPackage, Field, FieldSet, RankProfile, QueryField
from vespa.query import QueryModel, OR, RankProfile as Ranking, WeakAnd
from vespa.deployment import VespaDocker
from tenacity import retry, wait_exponential, stop_after_attempt, RetryError

# Characters stripped from queries before they are sent to Vespa: grouping
# parentheses, leading +/- query operators, and every common Unicode quote
# variant (quotes would otherwise trigger Vespa phrase matching).
QUOTES = [
    "\u0022",  # quotation mark (")
    "\u0027",  # apostrophe (')
    "\u00ab",  # left-pointing double-angle quotation mark
    "\u00bb",  # right-pointing double-angle quotation mark
    "\u2018",  # left single quotation mark
    "\u2019",  # right single quotation mark
    "\u201a",  # single low-9 quotation mark
    "\u201b",  # single high-reversed-9 quotation mark
    "\u201c",  # left double quotation mark
    "\u201d",  # right double quotation mark
    "\u201e",  # double low-9 quotation mark
    "\u201f",  # double high-reversed-9 quotation mark
    "\u2039",  # single left-pointing angle quotation mark
    "\u203a",  # single right-pointing angle quotation mark
    "\u300c",  # left corner bracket
    "\u300d",  # right corner bracket
    "\u300e",  # left white corner bracket
    "\u300f",  # right white corner bracket
    "\u301d",  # reversed double prime quotation mark
    "\u301e",  # double prime quotation mark
    "\u301f",  # low double prime quotation mark
    "\ufe41",  # presentation form for vertical left corner bracket
    "\ufe42",  # presentation form for vertical right corner bracket
    "\ufe43",  # presentation form for vertical left corner white bracket
    "\ufe44",  # presentation form for vertical right corner white bracket
    "\uff02",  # fullwidth quotation mark
    "\uff07",  # fullwidth apostrophe
    "\uff62",  # halfwidth left corner bracket
    "\uff63",  # halfwidth right corner bracket
]
REPLACE_SYMBOLS = ["(", ")", " -", " +", *QUOTES]


def replace_symbols(x):
    """Return *x* with every symbol in REPLACE_SYMBOLS removed, in order."""
    cleaned = x
    for symbol in REPLACE_SYMBOLS:
        cleaned = cleaned.replace(symbol, "")
    return cleaned


class VespaLexicalSearch:
    """Lexical retrieval over a Vespa application for BEIR benchmarks.

    Supports two match phases ('or', 'weak_and') and two rank phases
    ('bm25', 'native_rank'). When ``initialize`` is True a fresh Vespa
    application is deployed in a local Docker container; otherwise an
    already-running application is reached via ``deployment_parameters``
    (which must then contain pyvespa connection arguments such as
    'url' and 'port').
    """

    def __init__(
        self,
        application_name: str,
        match_phase: str = "or",
        rank_phase: str = "bm25",
        deployment_parameters: Optional[Dict] = None,
        initialize: bool = True,
    ):
        self.results = {}
        # Vespa application names may not contain dashes.
        self.application_name = application_name.replace("-", "")
        assert match_phase in [
            "or",
            "weak_and",
        ], "'match_phase' should be either 'or' or 'weak_and'"
        self.match_phase = match_phase
        assert rank_phase in [
            "bm25",
            "native_rank",
        ], "'rank_phase' should be either 'bm25' or 'native_rank'"
        self.rank_phase = rank_phase
        self.deployment_parameters = deployment_parameters
        self.initialize = initialize
        self.vespa_docker = None
        if self.initialize:
            # Deploy a new application (also clears any pre-existing docs).
            self.app = self.initialise()
            assert (
                self.app.get_application_status().status_code == 200
            ), "Application status different than 200."
        else:
            # Reuse a running container; keep a docker handle so that
            # remove_app() can still tear it down later.
            self.vespa_docker = VespaDocker.from_container_name_or_id(
                self.application_name
            )
            assert self.deployment_parameters is not None, (
                "if 'initialize' is set to false, 'deployment_parameters' should contain Vespa "
                "connection parameters such as 'url' and 'port'"
            )
            self.app = Vespa(**self.deployment_parameters)
            assert (
                self.app.get_application_status().status_code == 200
            ), "Application status different than 200."

    def initialise(self):
        """Build the application package, deploy it via Docker and return a
        connected :class:`vespa.application.Vespa` instance with an empty index."""
        #
        # Create Vespa application package
        #
        app_package = ApplicationPackage(name=self.application_name)
        app_package.schema.add_fields(
            Field(name="id", type="string", indexing=["attribute", "summary"]),
            Field(
                name="title",
                type="string",
                indexing=["index"],
                index="enable-bm25",
            ),
            Field(
                name="body",
                type="string",
                indexing=["index"],
                index="enable-bm25",
            ),
        )
        # Queries search both title and body by default.
        app_package.schema.add_field_set(
            FieldSet(name="default", fields=["title", "body"])
        )
        app_package.schema.add_rank_profile(
            rank_profile=RankProfile(
                name="bm25", first_phase="bm25(title) + bm25(body)"
            )
        )
        app_package.schema.add_rank_profile(
            rank_profile=RankProfile(
                name="native_rank", first_phase="nativeRank(title,body)"
            )
        )
        # Allow deep result lists (BEIR evaluates up to k=1000).
        app_package.query_profile.add_fields(QueryField(name="maxHits", value=10000))
        #
        # Deploy application
        #
        if not self.deployment_parameters:
            self.deployment_parameters = {"port": 8089, "container_memory": "12G"}
        self.vespa_docker = VespaDocker(**self.deployment_parameters)
        app = self.vespa_docker.deploy(application_package=app_package)
        # Start from a clean index in case the container was reused.
        app.delete_all_docs(
            content_cluster_name=self.application_name + "_content",
            schema=self.application_name,
        )
        return app

    def remove_app(self):
        """Delete the local application-package folder and stop/remove the
        Docker container, if one is attached to this instance."""
        if self.vespa_docker:
            shutil.rmtree(
                self.application_name, ignore_errors=True
            )  # remove application package folder
            self.vespa_docker.container.stop(timeout=600)  # stop docker container
            self.vespa_docker.container.remove()  # rm docker container

    @retry(wait=wait_exponential(multiplier=1), stop=stop_after_attempt(10))
    def send_query_batch(
        self, query_batch, query_model, hits, timeout=100, async_connections=50
    ):
        """Send one batch of queries, retrying with exponential backoff.

        ``timeout`` is a number of seconds: it is multiplied by the batch
        size for the overall async deadline and stringified for the
        per-query Vespa timeout parameter.
        """
        query_results = self.app.query_batch(
            query_batch=query_batch,
            query_model=query_model,
            connections=async_connections,
            total_timeout=timeout * len(query_batch),
            hits=hits,
            **{"timeout": str(timeout) + " s", "ranking.softtimeout.enable": "false"}
        )
        return query_results

    def process_queries(
        self, query_ids, queries, query_model, hits, batch_size, timeout=100, async_connections=50
    ):
        """Run all queries in batches and return ``{query_id: {corpus_id: score}}``.

        Batches that still fail after all retries are skipped; a query hit
        whose id equals the query id is excluded
        (see https://github.com/UKPLab/beir/issues/72).
        """
        results = {}
        assert len(query_ids) == len(
            queries
        ), "There must be one query_id for each query."
        query_id_batches = [
            query_ids[i : i + batch_size] for i in range(0, len(query_ids), batch_size)
        ]
        query_batches = [
            queries[i : i + batch_size] for i in range(0, len(queries), batch_size)
        ]
        for idx, (query_id_batch, query_batch) in enumerate(
            zip(query_id_batches, query_batches)
        ):
            print(
                "{}, {}, {}: {}/{}".format(
                    self.application_name,
                    self.match_phase,
                    self.rank_phase,
                    idx,
                    len(query_batches),
                )
            )
            try:
                query_results = self.send_query_batch(
                    query_batch=query_batch,
                    query_model=query_model,
                    hits=hits,
                    timeout=timeout,
                    async_connections=async_connections
                )
                number_hits = [x.number_documents_retrieved for x in query_results]
                status_code_summary = Counter([x.status_code for x in query_results])
                print(
                    "Sucessfull queries: {}/{}\nDocuments retrieved. Min: {}, Max: {}, Mean: {}, Median: {}.".format(
                        status_code_summary[200],
                        len(query_batch),
                        min(number_hits),
                        max(number_hits),
                        round(mean(number_hits), 2),
                        round(median(number_hits), 2),
                    )
                )
            except RetryError:
                # Give up on this batch after exhausting all retries.
                continue
            for (query_id, query_result) in zip(query_id_batch, query_results):
                scores = {}
                try:
                    if query_result.hits:
                        for hit in query_result.hits:
                            corpus_id = hit["fields"]["id"]
                            if (
                                corpus_id != query_id
                            ):  # See https://github.com/UKPLab/beir/issues/72
                                scores[corpus_id] = hit["relevance"]
                except KeyError:
                    # Malformed hit (missing fields/id) - skip this query.
                    continue
                results[query_id] = scores
        return results

    def search(
        self,
        corpus: Dict[str, Dict[str, str]],
        queries: Dict[str, str],
        top_k: int,
        *args,
        **kwargs
    ) -> Dict[str, Dict[str, float]]:
        """Index *corpus* (when initializing) and retrieve *top_k* documents
        for each query, returning ``{query_id: {corpus_id: score}}``."""

        if self.initialize:
            _ = self.index(corpus)

        # retrieve results from BM25
        query_ids = list(queries.keys())
        queries = [queries[qid] for qid in query_ids]

        queries = [
            re.sub(" +", " ", replace_symbols(x)).strip() for x in queries
        ]  # remove quotes and double spaces from queries

        if self.match_phase == "or":
            match_phase = OR()
        elif self.match_phase == "weak_and":
            match_phase = WeakAnd(hits=top_k)
        else:
            # BUGFIX: the exception was previously constructed but not raised.
            raise ValueError("'match_phase' should be either 'or' or 'weak_and'")

        if self.rank_phase not in ["bm25", "native_rank"]:
            # BUGFIX: the exception was previously constructed but not raised.
            raise ValueError("'rank_phase' should be either 'bm25' or 'native_rank'")

        query_model = QueryModel(
            name=self.match_phase + "_" + self.rank_phase,
            match_phase=match_phase,
            rank_profile=Ranking(name=self.rank_phase, list_features=False),
        )

        self.results = self.process_queries(
            query_ids=query_ids,
            queries=queries,
            query_model=query_model,
            hits=top_k,
            batch_size=1000,
            # BUGFIX: must be numeric seconds, not the string "100 s" -
            # send_query_batch multiplies it by the batch size and
            # stringifies it itself.
            timeout=100,
        )
        return self.results

    @retry(wait=wait_exponential(multiplier=1), stop=stop_after_attempt(10))
    def send_feed_batch(self, feed_batch, total_timeout=10000):
        """Feed one batch of documents, retrying with exponential backoff."""
        feed_results = self.app.feed_batch(
            batch=feed_batch, total_timeout=total_timeout
        )
        return feed_results

    def index(self, corpus: Dict[str, Dict[str, str]], batch_size=1000):
        """Feed the whole *corpus* to Vespa in mini-batches of ``batch_size``.

        Each corpus entry maps its 'title' and 'text' fields onto the
        schema's 'title' and 'body'. Returns 0 on completion.
        """
        batch_feed = [
            {
                "id": idx,
                "fields": {
                    "id": idx,
                    "title": corpus[idx].get("title", None),
                    "body": corpus[idx].get("text", None),
                },
            }
            for idx in list(corpus.keys())
        ]
        mini_batches = [
            batch_feed[i : i + batch_size]
            for i in range(0, len(batch_feed), batch_size)
        ]
        for idx, feed_batch in enumerate(mini_batches):
            feed_results = self.send_feed_batch(feed_batch=feed_batch)
            status_code_summary = Counter([x.status_code for x in feed_results])
            print(
                "Successful documents fed: {}/{}.\nBatch progress: {}/{}.".format(
                    status_code_summary[200], len(feed_batch), idx, len(mini_batches)
                )
            )
        return 0
65 changes: 65 additions & 0 deletions beir/test_retrieval_lexical_vespa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import unittest
from beir.retrieval.search.lexical.vespa_search import VespaLexicalSearch
from beir.retrieval.evaluation import EvaluateRetrieval


class TestVespaSearch(unittest.TestCase):
    """Integration tests for VespaLexicalSearch over every combination of
    match phase ('or', 'weak_and') and rank phase ('bm25', 'native_rank').

    Requires Docker: each test deploys a Vespa application locally.
    """

    def setUp(self) -> None:
        self.application_name = "vespa_test"
        # Tiny corpus: 3 documents, 2 queries; every query should match
        # every document under lexical retrieval.
        self.corpus = {
            "1": {"title": "this is a title 1", "text": "this is text 1"},
            "2": {"title": "this is a title 2", "text": "this is text 2"},
            "3": {"title": "this is a title 3", "text": "this is text 3"},
        }
        self.queries = {"1": "this is query 1", "2": "this is query 2"}

    def _retrieve_and_check(self, **model_kwargs) -> None:
        # Shared assertion body: build the model, run retrieval and check
        # that every query returns all three corpus documents.
        self.model = VespaLexicalSearch(
            application_name=self.application_name, **model_kwargs
        )
        retriever = EvaluateRetrieval(self.model)
        results = retriever.retrieve(corpus=self.corpus, queries=self.queries)
        self.assertEqual({"1", "2"}, set(results.keys()))
        for query_id in results.keys():
            self.assertEqual({"1", "2", "3"}, set(results[query_id].keys()))

    def test_or_bm25(self):
        self._retrieve_and_check(initialize=True)

    def test_or_native_rank(self):
        self._retrieve_and_check(
            initialize=True, match_phase="or", rank_phase="native_rank"
        )

    def test_weakand_bm25(self):
        self._retrieve_and_check(initialize=True, match_phase="weak_and")

    def test_weakand_native_rank(self):
        self._retrieve_and_check(
            initialize=True, match_phase="weak_and", rank_phase="native_rank"
        )

    def tearDown(self) -> None:
        # BUGFIX: only clean up when the model was actually constructed -
        # otherwise a failure inside VespaLexicalSearch(...) would be masked
        # by an AttributeError raised here.
        if hasattr(self, "model"):
            self.model.remove_app()
Loading