Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

return proper acquisition function value #56

Merged
merged 2 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions afsl/acquisition_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,15 @@ def select(
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch.

:param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
:param model: Model used for data selection.
:param dataset: Inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch.
:return: Indices of the newly selected batch and corresponding values of the acquisition function.
"""
pass

Expand Down Expand Up @@ -141,7 +141,7 @@ def select(
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
return BatchAcquisitionFunction._select(
compute_fn=self.compute,
batch_size=batch_size,
Expand All @@ -163,7 +163,7 @@ def _select(
mini_batch_size: int,
num_workers: int,
subsample: bool,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
indexed_dataset = _IndexedDataset(dataset)
data_loader = DataLoader(
indexed_dataset,
Expand All @@ -182,8 +182,8 @@ def _select(
values = torch.cat(_values)
original_indices = torch.cat(_original_indices)

_, indices = torch.topk(values, batch_size)
return original_indices[indices.cpu()]
values, indices = torch.topk(values, batch_size)
return original_indices[indices.cpu()], values.cpu()


State = TypeVar("State")
Expand Down Expand Up @@ -261,33 +261,35 @@ def selector(values: torch.Tensor) -> int:

def select_from_minibatch(
self, batch_size: int, model: M, data: torch.Tensor, device: torch.device | None
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch from the given mini batch `data`.

:param batch_size: Size of the batch to be selected. Needs to be smaller than `mini_batch_size`.
:param model: Model used for data selection.
:param data: Mini batch of inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch (with respect to mini batch).
:return: Indices of the newly selected batch (with respect to mini batch) and corresponding values of the acquisition function.
"""
state = self.initialize(model, data, device)

indices = []
selected_indices = []
selected_values = []
for _ in range(batch_size):
values = self.compute(state)
i = self.selector(values)
indices.append(i)
selected_indices.append(i)
selected_values.append(values[i])
state = self.step(state, i)
return torch.tensor(indices)
return torch.tensor(selected_indices), torch.tensor(selected_values)

def select(
self,
batch_size: int,
model: M,
dataset: Dataset,
device: torch.device | None = None,
) -> torch.Tensor:
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch. If `force_nonsequential` is `True`, the data is selected analogously to `BatchAcquisitionFunction.select`.
Otherwise, the data is selected by hierarchical composition of data selected from mini batches.
Expand All @@ -296,7 +298,7 @@ def select(
:param model: Model used for data selection.
:param dataset: Inputs (shape $n \times d$) to be selected from.
:param device: Device used for computation of the acquisition function.
:return: Indices of the newly selected batch.
:return: Indices of the newly selected batch and corresponding values of the acquisition function.
"""
if self.force_nonsequential:

Expand Down Expand Up @@ -326,7 +328,9 @@ def compute_fn(

indexed_dataset = _IndexedDataset(dataset)
selected_indices = range(len(dataset))
while len(selected_indices) > batch_size:
while (
len(selected_indices) > batch_size
): # gradually shrinks size of selected batch, until the correct size is reached
data_loader = DataLoader(
indexed_dataset,
batch_size=self.mini_batch_size,
Expand All @@ -335,16 +339,17 @@ def compute_fn(
)

selected_indices = []
selected_values = []
for data, idx in data_loader:
selected_indices.extend(
idx[self.select_from_minibatch(batch_size, model, data, device)]
.cpu()
.tolist()
sub_idx, sub_val = self.select_from_minibatch(
batch_size, model, data, device
)
selected_indices.extend(idx[sub_idx].cpu().tolist())
selected_values.extend(sub_val.cpu().tolist())
if self.subsample:
break
indexed_dataset = Subset(indexed_dataset, selected_indices)
return torch.tensor(selected_indices)
return torch.tensor(selected_indices), torch.tensor(selected_values)


class EmbeddingBased(ABC):
Expand Down
6 changes: 3 additions & 3 deletions afsl/active_data_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Generic
from typing import Generic, Tuple
import torch
from afsl.acquisition_functions import M, AcquisitionFunction, Targeted
from afsl.acquisition_functions.undirected_vtl import UndirectedVTL
Expand Down Expand Up @@ -130,7 +130,7 @@ def initialize(
device=device,
)

def next(self, model: M | None = None) -> torch.Tensor:
def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`.

Expand All @@ -139,7 +139,7 @@ def next(self, model: M | None = None) -> torch.Tensor:
The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to value other than `None`.

:param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded.
:return: Indices of the selected data.
:return: Indices of the selected data and corresponding values of the acquisition function in the format `(indices, values)`.
"""

return self.acquisition_function.select(
Expand Down
4 changes: 2 additions & 2 deletions afsl/adapters/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ def engine(i: int) -> Tuple[np.ndarray, np.ndarray]:
if isinstance(self.acquisition_function, Targeted):
self.acquisition_function.set_target(target)

sub_indexes = ActiveDataLoader(
sub_indexes, values = ActiveDataLoader(
dataset=dataset,
batch_size=k,
acquisition_function=self.acquisition_function,
device=self.device,
).next()
return np.array(I[i][sub_indexes]), np.array(V[i][sub_indexes])
return np.array(I[i][sub_indexes]), np.array(values)

resulting_indices = []
resulting_values = []
Expand Down