From c4fac28233dfbaf285b5c5ca25add3c19985dac0 Mon Sep 17 00:00:00 2001 From: Bongni Date: Thu, 29 Aug 2024 17:06:02 +0200 Subject: [PATCH 1/4] Added time measurement to the faiss adapter. --- afsl/adapters/faiss.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/afsl/adapters/faiss.py b/afsl/adapters/faiss.py index d6fa254..a2885db 100644 --- a/afsl/adapters/faiss.py +++ b/afsl/adapters/faiss.py @@ -2,6 +2,7 @@ from afsl.acquisition_functions import AcquisitionFunction, Targeted import faiss # type: ignore import torch +import time import concurrent.futures import numpy as np from afsl import ActiveDataLoader @@ -104,8 +105,10 @@ def batch_search( assert d == self.index.d mean_queries = np.mean(queries, axis=1) + t_start = time.time() faiss.omp_set_num_threads(threads) # type: ignore D, I, V = self.index.search_and_reconstruct(mean_queries, k or self.index.ntotal) # type: ignore + t_faiss = time.time() - t_start if self.only_faiss: return D[:, :N], I[:, :N], V[:, :N] @@ -131,6 +134,7 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: np.array(V[i][sub_indexes]), ) + t_start = time.time() resulting_values = [] resulting_indices = [] resulting_embeddings = [] @@ -143,8 +147,13 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: resulting_values.append(values) resulting_indices.append(indices) resulting_embeddings.append(embeddings) + t_afsl = time.time() - t_start return ( np.array(resulting_values), np.array(resulting_indices), np.array(resulting_embeddings), + { + "faiss": t_faiss, + "afsl": t_afsl, + } ) From 71e1d62593d55140001a66593f47c980b5d60107 Mon Sep 17 00:00:00 2001 From: Bongni Date: Thu, 29 Aug 2024 17:14:46 +0200 Subject: [PATCH 2/4] Formatted with black. --- afsl/adapters/faiss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/afsl/adapters/faiss.py b/afsl/adapters/faiss.py index a2885db..4e68936 100644 --- a/afsl/adapters/faiss.py +++ b/afsl/adapters/faiss.py @@ -155,5 +155,5 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: { "faiss": t_faiss, "afsl": t_afsl, - } + }, ) From 0cd16a240b0e8943f7ca792c537bb5060ec6a2ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Thu, 29 Aug 2024 18:31:04 +0200 Subject: [PATCH 3/4] fixes --- afsl/adapters/faiss.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/afsl/adapters/faiss.py b/afsl/adapters/faiss.py index 4e68936..c6d5491 100644 --- a/afsl/adapters/faiss.py +++ b/afsl/adapters/faiss.py @@ -1,4 +1,4 @@ -from typing import Tuple +from typing import NamedTuple, Tuple from afsl.acquisition_functions import AcquisitionFunction, Targeted import faiss # type: ignore import torch @@ -20,6 +20,13 @@ def __getitem__(self, index) -> torch.Tensor: return self.data[index] +class RetrievalTime(NamedTuple): + faiss: float + """Time spent with Faiss retrieval.""" + afsl: float + """Additional time spent with AFSL.""" + + class Retriever: """ Adapter for the [Faiss](https://github.com/facebookresearch/faiss) library. @@ -62,7 +69,7 @@ def search( k: int | None, mean_pooling: bool = False, threads: int = 1, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]: r""" :param query: Query embedding (of shape $m \times d$), comprised of $m$ individual embeddings. :param N: Number of results to return. @@ -70,16 +77,16 @@ def search( :param mean_pooling: Whether to use the mean of the query embeddings. :param threads: Number of threads to use. - :return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), and array of corresponding embeddings (of shape $N \times d$). + :return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), array of corresponding embeddings (of shape $N \times d$), retrieval time. """ - D, I, V = self.batch_search( + D, I, V, retrieval_time = self.batch_search( queries=np.array([query]), N=N, k=k, mean_pooling=mean_pooling, threads=threads, ) - return D[0], I[0], V[0] + return D[0], I[0], V[0], retrieval_time def batch_search( self, @@ -88,7 +95,7 @@ def batch_search( k: int | None = None, mean_pooling: bool = False, threads: int = 1, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]: r""" :param queries: $n$ query embeddings (of combined shape $n \times m \times d$), each comprised of $m$ individual embeddings. :param N: Number of results to return. @@ -96,7 +103,7 @@ def batch_search( :param mean_pooling: Whether to use the mean of the query embeddings. :param threads: Number of threads to use. - :return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), and array of corresponding embeddings (of shape $n \times N \times d$). + :return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), array of corresponding embeddings (of shape $n \times N \times d$), retrieval time. """ assert k is None or k >= N @@ -111,7 +118,8 @@ def batch_search( t_faiss = time.time() - t_start if self.only_faiss: - return D[:, :N], I[:, :N], V[:, :N] + retrieval_time = RetrievalTime(faiss=t_faiss, afsl=0) + return D[:, :N], I[:, :N], V[:, :N], retrieval_time def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: dataset = Dataset(torch.tensor(V[i])) @@ -148,12 +156,10 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: resulting_indices.append(indices) resulting_embeddings.append(embeddings) t_afsl = time.time() - t_start + retrieval_time = RetrievalTime(faiss=t_faiss, afsl=t_afsl) return ( np.array(resulting_values), np.array(resulting_indices), np.array(resulting_embeddings), - { - "faiss": t_faiss, - "afsl": t_afsl, - }, + retrieval_time, ) From a95597cdbb53c5596e63ac49853e061a07c47c30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Thu, 29 Aug 2024 18:34:27 +0200 Subject: [PATCH 4/4] fix pyright errors --- examples/fine_tuning/cifar_100/experiment.py | 1 - examples/fine_tuning/mnist/experiment.py | 1 - examples/fine_tuning/training.py | 2 +- requirements.txt | 2 +- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/fine_tuning/cifar_100/experiment.py b/examples/fine_tuning/cifar_100/experiment.py index 65a94df..109d360 100644 --- a/examples/fine_tuning/cifar_100/experiment.py +++ b/examples/fine_tuning/cifar_100/experiment.py @@ -166,7 +166,6 @@ def experiment( faiss_index_path=faiss_index_path, target_embeddings=target_embeddings, ) - wandb.finish() def main(args): diff --git a/examples/fine_tuning/mnist/experiment.py b/examples/fine_tuning/mnist/experiment.py index d0a02a8..2f7a12e 100644 --- a/examples/fine_tuning/mnist/experiment.py +++ b/examples/fine_tuning/mnist/experiment.py @@ -146,7 +146,6 @@ def experiment( reset_parameters=RESET_PARAMS, use_best_model=USE_BEST_MODEL, ) - wandb.finish() def main(args): diff --git a/examples/fine_tuning/training.py b/examples/fine_tuning/training.py index 0f5cae9..33768bc 100644 --- a/examples/fine_tuning/training.py +++ b/examples/fine_tuning/training.py @@ -112,7 +112,7 @@ def train_loop( target_embeddings ) # ensure target set is reset to correct length query = acquisition_function.get_target().cpu().numpy() - _, _batch_indices, _ = retriever.search( + _, _batch_indices, _, _ = retriever.search( query=query, N=query_batch_size, k=100 * query_batch_size ) batch_indices = torch.tensor(_batch_indices) diff --git a/requirements.txt b/requirements.txt index 44f6c27..0ff3f2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,4 @@ pytest~=8.3 torch~=2.4 torchvision~=0.19 tqdm~=4.66 -wandb +wandb~=0.17