Skip to content

Commit

Permalink
return proper acquisition function value (#56)
Browse files Browse the repository at this point in the history
* return proper acquisition function value

* black
  • Loading branch information
jonhue authored Aug 28, 2024
1 parent edf260e commit 7b4d3ea
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 24 deletions.
43 changes: 24 additions & 19 deletions afsl/acquisition_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ def select(
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch.
:param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
:param model: Model used for data selection.
:param dataset: Inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch.
:return: Indices of the newly selected batch and corresponding values of the acquisition function.
"""
pass

Expand Down Expand Up @@ -141,7 +141,7 @@ def select(
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
return BatchAcquisitionFunction._select(
compute_fn=self.compute,
batch_size=batch_size,
Expand All @@ -163,7 +163,7 @@ def _select(
mini_batch_size: int,
num_workers: int,
subsample: bool,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
indexed_dataset = _IndexedDataset(dataset)
data_loader = DataLoader(
indexed_dataset,
Expand All @@ -182,8 +182,8 @@ def _select(
values = torch.cat(_values)
original_indices = torch.cat(_original_indices)

_, indices = torch.topk(values, batch_size)
return original_indices[indices.cpu()]
values, indices = torch.topk(values, batch_size)
return original_indices[indices.cpu()], values.cpu()


State = TypeVar("State")
Expand Down Expand Up @@ -261,33 +261,35 @@ def selector(values: torch.Tensor) -> int:

def select_from_minibatch(
self, batch_size: int, model: M, data: torch.Tensor, device: torch.device | None
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch from the given mini batch `data`.
:param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
:param model: Model used for data selection.
:param data: Mini batch of inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch (with respect to mini batch).
:return: Indices of the newly selected batch (with respect to mini batch) and corresponding values of the acquisition function.
"""
state = self.initialize(model, data, device)

indices = []
selected_indices = []
selected_values = []
for _ in range(batch_size):
values = self.compute(state)
i = self.selector(values)
indices.append(i)
selected_indices.append(i)
selected_values.append(values[i])
state = self.step(state, i)
return torch.tensor(indices)
return torch.tensor(selected_indices), torch.tensor(selected_values)

def select(
self,
batch_size: int,
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch. If `force_nonsequential` is `True`, the data is selected analogously to `BatchAcquisitionFunction.select`.
Otherwise, the data is selected by hierarchical composition of data selected from mini batches.
Expand All @@ -296,7 +298,7 @@ def select(
:param model: Model used for data selection.
:param dataset: Inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch.
:return: Indices of the newly selected batch and corresponding values of the acquisition function.
"""
if self.force_nonsequential:

Expand Down Expand Up @@ -326,7 +328,9 @@ def compute_fn(

indexed_dataset = _IndexedDataset(dataset)
selected_indices = range(len(dataset))
while len(selected_indices) > batch_size:
while (
len(selected_indices) > batch_size
): # gradually shrinks size of selected batch, until the correct size is reached
data_loader = DataLoader(
indexed_dataset,
batch_size=self.mini_batch_size,
Expand All @@ -335,16 +339,17 @@ def compute_fn(
)

selected_indices = []
selected_values = []
for data, idx in data_loader:
selected_indices.extend(
idx[self.select_from_minibatch(batch_size, model, data, device)]
.cpu()
.tolist()
sub_idx, sub_val = self.select_from_minibatch(
batch_size, model, data, device
)
selected_indices.extend(idx[sub_idx].cpu().tolist())
selected_values.extend(sub_val.cpu().tolist())
if self.subsample:
break
indexed_dataset = Subset(indexed_dataset, selected_indices)
return torch.tensor(selected_indices)
return torch.tensor(selected_indices), torch.tensor(selected_values)


class EmbeddingBased(ABC):
Expand Down
6 changes: 3 additions & 3 deletions afsl/active_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Generic
from typing import Generic, Tuple
import torch
from afsl.acquisition_functions import M, AcquisitionFunction, Targeted
from afsl.acquisition_functions.undirected_vtl import UndirectedVTL
Expand Down Expand Up @@ -130,7 +130,7 @@ def initialize(
device=device,
)

def next(self, model: M | None = None) -> torch.Tensor:
def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`.
Expand All @@ -139,7 +139,7 @@ def next(self, model: M | None = None) -> torch.Tensor:
The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to value other than `None`.
:param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded.
:return: Indices of the selected data.
:return: Indices of the selected data and corresponding value of the acquisition function in the format `(indices, values)`.
"""

return self.acquisition_function.select(
Expand Down
4 changes: 2 additions & 2 deletions afsl/adapters/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray]:
if isinstance(self.acquisition_function, Targeted):
self.acquisition_function.set_target(target)

sub_indexes = ActiveDataLoader(
sub_indexes, values = ActiveDataLoader(
dataset=dataset,
batch_size=k,
acquisition_function=self.acquisition_function,
device=self.device,
).next()
return np.array(I[i][sub_indexes]), np.array(V[i][sub_indexes])
return np.array(I[i][sub_indexes]), np.array(values)

resulting_indices = []
resulting_values = []
Expand Down

0 comments on commit 7b4d3ea

Please sign in to comment.