From 98eb8fe7c09c72c200f1c6b5b708c3bd2043be5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Wed, 28 Aug 2024 15:38:06 +0200 Subject: [PATCH 1/2] return proper acquisition function value --- afsl/acquisition_functions/__init__.py | 41 +++++++++++++------------- afsl/active_data_loader.py | 6 ++-- afsl/adapters/faiss.py | 4 +-- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/afsl/acquisition_functions/__init__.py b/afsl/acquisition_functions/__init__.py index eed9c2d..9e03455 100644 --- a/afsl/acquisition_functions/__init__.py +++ b/afsl/acquisition_functions/__init__.py @@ -100,7 +100,7 @@ def select( model: M, dataset: Dataset, device: torch.device | None = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Selects the next batch. @@ -108,7 +108,7 @@ def select( :param model: Model used for data selection. :param dataset: Inputs (shape $n \times d$) to be selected from. :param device: Device used for computation of the acquisition function. - :return: Indices of the newly selected batch. + :return: Indices of the newly selected batch and corresponding values of the acquisition function. """ pass @@ -141,7 +141,7 @@ def select( model: M, dataset: Dataset, device: torch.device | None = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: return BatchAcquisitionFunction._select( compute_fn=self.compute, batch_size=batch_size, @@ -163,7 +163,7 @@ def _select( mini_batch_size: int, num_workers: int, subsample: bool, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: indexed_dataset = _IndexedDataset(dataset) data_loader = DataLoader( indexed_dataset, @@ -182,8 +182,8 @@ def _select( values = torch.cat(_values) original_indices = torch.cat(_original_indices) - _, indices = torch.topk(values, batch_size) - return original_indices[indices.cpu()] + values, indices = torch.topk(values, batch_size) + return original_indices[indices.cpu()], values.cpu() State = TypeVar("State") @@ -261,7 +261,7 @@ def selector(values: torch.Tensor) -> int: def select_from_minibatch( self, batch_size: int, model: M, data: torch.Tensor, device: torch.device | None - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Selects the next batch from the given mini batch `data`. @@ -269,17 +269,19 @@ def select_from_minibatch( :param model: Model used for data selection. :param data: Mini batch of inputs (shape $n \times d$) to be selected from. :param device: Device used for computation of the acquisition function. - :return: Indices of the newly selected batch (with respect to mini batch). + :return: Indices of the newly selected batch (with respect to mini batch) and corresponding values of the acquisition function. """ state = self.initialize(model, data, device) - indices = [] + selected_indices = [] + selected_values = [] for _ in range(batch_size): values = self.compute(state) i = self.selector(values) - indices.append(i) + selected_indices.append(i) + selected_values.append(values[i]) state = self.step(state, i) - return torch.tensor(indices) + return torch.tensor(selected_indices), torch.tensor(selected_values) def select( self, @@ -287,7 +289,7 @@ def select( model: M, dataset: Dataset, device: torch.device | None = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Selects the next batch. If `force_nonsequential` is `True`, the data is selected analogously to `BatchAcquisitionFunction.select`. Otherwise, the data is selected by hierarchical composition of data selected from mini batches. @@ -296,7 +298,7 @@ def select( :param model: Model used for data selection. :param dataset: Inputs (shape $n \times d$) to be selected from. :param device: Device used for computation of the acquisition function. - :return: Indices of the newly selected batch. + :return: Indices of the newly selected batch and corresponding values of the acquisition function. """ if self.force_nonsequential: @@ -326,7 +328,7 @@ def compute_fn( indexed_dataset = _IndexedDataset(dataset) selected_indices = range(len(dataset)) - while len(selected_indices) > batch_size: + while len(selected_indices) > batch_size: # gradually shrinks size of selected batch, until the correct size is reached data_loader = DataLoader( indexed_dataset, batch_size=self.mini_batch_size, @@ -335,16 +337,15 @@ def compute_fn( ) selected_indices = [] + selected_values = [] for data, idx in data_loader: - selected_indices.extend( - idx[self.select_from_minibatch(batch_size, model, data, device)] - .cpu() - .tolist() - ) + sub_idx, sub_val = self.select_from_minibatch(batch_size, model, data, device) + selected_indices.extend(idx[sub_idx].cpu().tolist()) + selected_values.extend(sub_val.cpu().tolist()) if self.subsample: break indexed_dataset = Subset(indexed_dataset, selected_indices) - return torch.tensor(selected_indices) + return torch.tensor(selected_indices), torch.tensor(selected_values) class EmbeddingBased(ABC): diff --git a/afsl/active_data_loader.py b/afsl/active_data_loader.py index 6c5a70a..5c82213 100644 --- a/afsl/active_data_loader.py +++ b/afsl/active_data_loader.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Generic +from typing import Generic, Tuple import torch from afsl.acquisition_functions import M, AcquisitionFunction, Targeted from afsl.acquisition_functions.undirected_vtl import UndirectedVTL @@ -130,7 +130,7 @@ def initialize( device=device, ) - def next(self, model: M | None = None) -> torch.Tensor: + def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]: r""" Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`. @@ -139,7 +139,7 @@ def next(self, model: M | None = None) -> torch.Tensor: The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to value other than `None`. :param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded. - :return: Indices of the selected data. + :return: Indices of the selected data and corresponding value of the acquisition function in the format `(indices, values)`. """ return self.acquisition_function.select( diff --git a/afsl/adapters/faiss.py b/afsl/adapters/faiss.py index f676e48..84f3076 100644 --- a/afsl/adapters/faiss.py +++ b/afsl/adapters/faiss.py @@ -119,13 +119,13 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray]: if isinstance(self.acquisition_function, Targeted): self.acquisition_function.set_target(target) - sub_indexes = ActiveDataLoader( + sub_indexes, values = ActiveDataLoader( dataset=dataset, batch_size=k, acquisition_function=self.acquisition_function, device=self.device, ).next() - return np.array(I[i][sub_indexes]), np.array(V[i][sub_indexes]) + return np.array(I[i][sub_indexes]), np.array(values) resulting_indices = [] resulting_values = [] From 11ec4ed4bc326f591fe20c9f206cc9763d5abde5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonas=20H=C3=BCbotter?= Date: Wed, 28 Aug 2024 15:46:25 +0200 Subject: [PATCH 2/2] black --- afsl/acquisition_functions/__init__.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/afsl/acquisition_functions/__init__.py b/afsl/acquisition_functions/__init__.py index 9e03455..f84df9f 100644 --- a/afsl/acquisition_functions/__init__.py +++ b/afsl/acquisition_functions/__init__.py @@ -328,7 +328,9 @@ def compute_fn( indexed_dataset = _IndexedDataset(dataset) selected_indices = range(len(dataset)) - while len(selected_indices) > batch_size: # gradually shrinks size of selected batch, until the correct size is reached + while ( + len(selected_indices) > batch_size + ): # gradually shrinks size of selected batch, until the correct size is reached data_loader = DataLoader( indexed_dataset, batch_size=self.mini_batch_size, @@ -339,7 +341,9 @@ def compute_fn( selected_indices = [] selected_values = [] for data, idx in data_loader: - sub_idx, sub_val = self.select_from_minibatch(batch_size, model, data, device) + sub_idx, sub_val = self.select_from_minibatch( + batch_size, model, data, device + ) selected_indices.extend(idx[sub_idx].cpu().tolist()) selected_values.extend(sub_val.cpu().tolist()) if self.subsample: