Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Time #60

Merged
merged 4 commits into from
Aug 29, 2024
Merged

Time #60

Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions afsl/adapters/faiss.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Tuple
from typing import NamedTuple, Tuple
from afsl.acquisition_functions import AcquisitionFunction, Targeted
import faiss # type: ignore
import torch
import time
import concurrent.futures
import numpy as np
from afsl import ActiveDataLoader
Expand All @@ -19,6 +20,13 @@ def __getitem__(self, index) -> torch.Tensor:
return self.data[index]


class RetrievalTime(NamedTuple):
faiss: float
"""Time spent with Faiss retrieval."""
afsl: float
"""Additional time spent with AFSL."""


class Retriever:
"""
Adapter for the [Faiss](https://github.com/facebookresearch/faiss) library.
Expand Down Expand Up @@ -61,24 +69,24 @@ def search(
k: int | None,
mean_pooling: bool = False,
threads: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
r"""
:param query: Query embedding (of shape $m \times d$), comprised of $m$ individual embeddings.
:param N: Number of results to return.
:param k: Number of results to pre-sample with Faiss. Does not pre-sample if set to `None`.
:param mean_pooling: Whether to use the mean of the query embeddings.
:param threads: Number of threads to use.

:return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), and array of corresponding embeddings (of shape $N \times d$).
:return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), array of corresponding embeddings (of shape $N \times d$), retrieval time.
"""
D, I, V = self.batch_search(
D, I, V, retrieval_time = self.batch_search(
queries=np.array([query]),
N=N,
k=k,
mean_pooling=mean_pooling,
threads=threads,
)
return D[0], I[0], V[0]
return D[0], I[0], V[0], retrieval_time

def batch_search(
self,
Expand All @@ -87,15 +95,15 @@ def batch_search(
k: int | None = None,
mean_pooling: bool = False,
threads: int = 1,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
r"""
:param queries: $n$ query embeddings (of combined shape $n \times m \times d$), each comprised of $m$ individual embeddings.
:param N: Number of results to return.
:param k: Number of results to pre-sample with Faiss. Does not pre-sample if set to `None`.
:param mean_pooling: Whether to use the mean of the query embeddings.
:param threads: Number of threads to use.

:return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), and array of corresponding embeddings (of shape $n \times N \times d$).
:return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), array of corresponding embeddings (of shape $n \times N \times d$), retrieval time.
"""
assert k is None or k >= N

Expand All @@ -104,11 +112,14 @@ def batch_search(
assert d == self.index.d
mean_queries = np.mean(queries, axis=1)

t_start = time.time()
faiss.omp_set_num_threads(threads) # type: ignore
D, I, V = self.index.search_and_reconstruct(mean_queries, k or self.index.ntotal) # type: ignore
t_faiss = time.time() - t_start

if self.only_faiss:
return D[:, :N], I[:, :N], V[:, :N]
retrieval_time = RetrievalTime(faiss=t_faiss, afsl=0)
return D[:, :N], I[:, :N], V[:, :N], retrieval_time

def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
dataset = Dataset(torch.tensor(V[i]))
Expand All @@ -131,6 +142,7 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
np.array(V[i][sub_indexes]),
)

t_start = time.time()
resulting_values = []
resulting_indices = []
resulting_embeddings = []
Expand All @@ -143,8 +155,11 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
resulting_values.append(values)
resulting_indices.append(indices)
resulting_embeddings.append(embeddings)
t_afsl = time.time() - t_start
retrieval_time = RetrievalTime(faiss=t_faiss, afsl=t_afsl)
return (
np.array(resulting_values),
np.array(resulting_indices),
np.array(resulting_embeddings),
retrieval_time,
)
Loading