diff --git a/afsl/adapters/faiss.py b/afsl/adapters/faiss.py
index d6fa254..c6d5491 100644
--- a/afsl/adapters/faiss.py
+++ b/afsl/adapters/faiss.py
@@ -1,7 +1,8 @@
-from typing import Tuple
+from typing import NamedTuple, Tuple
 from afsl.acquisition_functions import AcquisitionFunction, Targeted
 import faiss  # type: ignore
 import torch
+import time
 import concurrent.futures
 import numpy as np
 from afsl import ActiveDataLoader
@@ -19,6 +20,13 @@ def __getitem__(self, index) -> torch.Tensor:
         return self.data[index]
 
 
+class RetrievalTime(NamedTuple):
+    faiss: float
+    """Time spent with Faiss retrieval."""
+    afsl: float
+    """Additional time spent with AFSL."""
+
+
 class Retriever:
     """
     Adapter for the [Faiss](https://github.com/facebookresearch/faiss) library.
@@ -61,7 +69,7 @@ def search(
         k: int | None,
         mean_pooling: bool = False,
         threads: int = 1,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
         r"""
         :param query: Query embedding (of shape $m \times d$), comprised of $m$ individual embeddings.
         :param N: Number of results to return.
@@ -69,16 +77,16 @@ def search(
         :param mean_pooling: Whether to use the mean of the query embeddings.
         :param threads: Number of threads to use.
 
-        :return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), and array of corresponding embeddings (of shape $N \times d$).
+        :return: Array of acquisition values (of length $N$), array of selected indices (of length $N$), array of corresponding embeddings (of shape $N \times d$), retrieval time.
         """
-        D, I, V = self.batch_search(
+        D, I, V, retrieval_time = self.batch_search(
             queries=np.array([query]),
             N=N,
             k=k,
             mean_pooling=mean_pooling,
             threads=threads,
         )
-        return D[0], I[0], V[0]
+        return D[0], I[0], V[0], retrieval_time
 
     def batch_search(
         self,
@@ -87,7 +95,7 @@ def batch_search(
         k: int | None = None,
         mean_pooling: bool = False,
         threads: int = 1,
-    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, RetrievalTime]:
         r"""
         :param queries: $n$ query embeddings (of combined shape $n \times m \times d$), each comprised of $m$ individual embeddings.
         :param N: Number of results to return.
@@ -95,7 +103,7 @@ def batch_search(
         :param mean_pooling: Whether to use the mean of the query embeddings.
         :param threads: Number of threads to use.
 
-        :return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), and array of corresponding embeddings (of shape $n \times N \times d$).
+        :return: Array of acquisition values (of shape $n \times N$), array of selected indices (of shape $n \times N$), array of corresponding embeddings (of shape $n \times N \times d$), retrieval time.
         """
         assert k is None or k >= N
 
@@ -104,11 +112,14 @@ def batch_search(
         assert d == self.index.d
 
         mean_queries = np.mean(queries, axis=1)
+        t_start = time.time()
         faiss.omp_set_num_threads(threads)  # type: ignore
         D, I, V = self.index.search_and_reconstruct(mean_queries, k or self.index.ntotal)  # type: ignore
+        t_faiss = time.time() - t_start
 
         if self.only_faiss:
-            return D[:, :N], I[:, :N], V[:, :N]
+            retrieval_time = RetrievalTime(faiss=t_faiss, afsl=0)
+            return D[:, :N], I[:, :N], V[:, :N], retrieval_time
 
         def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
             dataset = Dataset(torch.tensor(V[i]))
@@ -131,6 +142,7 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
                 np.array(V[i][sub_indexes]),
             )
 
+        t_start = time.time()
         resulting_values = []
         resulting_indices = []
         resulting_embeddings = []
@@ -143,8 +155,11 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
                 resulting_values.append(values)
                 resulting_indices.append(indices)
                 resulting_embeddings.append(embeddings)
+        t_afsl = time.time() - t_start
+        retrieval_time = RetrievalTime(faiss=t_faiss, afsl=t_afsl)
         return (
             np.array(resulting_values),
             np.array(resulting_indices),
             np.array(resulting_embeddings),
+            retrieval_time,
         )
diff --git a/examples/fine_tuning/cifar_100/experiment.py b/examples/fine_tuning/cifar_100/experiment.py
index 65a94df..109d360 100644
--- a/examples/fine_tuning/cifar_100/experiment.py
+++ b/examples/fine_tuning/cifar_100/experiment.py
@@ -166,7 +166,6 @@ def experiment(
         faiss_index_path=faiss_index_path,
         target_embeddings=target_embeddings,
     )
-    wandb.finish()
 
 
 def main(args):
diff --git a/examples/fine_tuning/mnist/experiment.py b/examples/fine_tuning/mnist/experiment.py
index d0a02a8..2f7a12e 100644
--- a/examples/fine_tuning/mnist/experiment.py
+++ b/examples/fine_tuning/mnist/experiment.py
@@ -146,7 +146,6 @@ def experiment(
         reset_parameters=RESET_PARAMS,
         use_best_model=USE_BEST_MODEL,
     )
-    wandb.finish()
 
 
 def main(args):
diff --git a/examples/fine_tuning/training.py b/examples/fine_tuning/training.py
index 0f5cae9..33768bc 100644
--- a/examples/fine_tuning/training.py
+++ b/examples/fine_tuning/training.py
@@ -112,7 +112,7 @@ def train_loop(
             target_embeddings
         )  # ensure target set is reset to correct length
         query = acquisition_function.get_target().cpu().numpy()
-        _, _batch_indices, _ = retriever.search(
+        _, _batch_indices, _, _ = retriever.search(
             query=query, N=query_batch_size, k=100 * query_batch_size
         )
         batch_indices = torch.tensor(_batch_indices)
diff --git a/requirements.txt b/requirements.txt
index 44f6c27..0ff3f2e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,4 @@ pytest~=8.3
 torch~=2.4
 torchvision~=0.19
 tqdm~=4.66
-wandb
+wandb~=0.17