From 433c5119493c41110132bbe8f0c47ff1bb9aa0fe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jonas=20H=C3=BCbotter?=
Date: Mon, 19 Feb 2024 18:06:19 +0100
Subject: [PATCH] refactor initial implementation

---
 afsl/__init__.py | 10 ++-
 afsl/acquisition_functions/__init__.py | 72 +++++++++------
 afsl/acquisition_functions/bace.py | 75 +++++++++++++---
 afsl/acquisition_functions/badge.py | 52 +++---------
 .../cosine_similarity.py | 33 ++++---
 afsl/acquisition_functions/ctl.py | 5 +-
 afsl/acquisition_functions/greedy_max_det.py | 21 +----
 afsl/acquisition_functions/greedy_max_dist.py | 17 ++--
 .../information_density.py | 24 +++---
 afsl/acquisition_functions/itl.py | 4 +-
 afsl/acquisition_functions/max_entropy.py | 5 --
 afsl/acquisition_functions/max_margin.py | 5 --
 afsl/acquisition_functions/vtl.py | 4 +-
 afsl/active_data_loader.py | 85 ++++++++++---------
 afsl/types.py | 3 -
 afsl/utils.py | 2 +
 docs/make.py | 5 +-
 17 files changed, 234 insertions(+), 188 deletions(-)
 delete mode 100644 afsl/types.py

diff --git a/afsl/__init__.py b/afsl/__init__.py
index d62f096..19de7d5 100644
--- a/afsl/__init__.py
+++ b/afsl/__init__.py
@@ -1,5 +1,13 @@
 """
-Active Few-Shot Learning
+*Active Few-Shot Learning* (`afsl`) is a Python package for intelligent active data selection.
+
+## Why Active Data Selection?
+
+## Getting Started
+
+### Installation
+
+---
 """

 from afsl.active_data_loader import ActiveDataLoader
diff --git a/afsl/acquisition_functions/__init__.py b/afsl/acquisition_functions/__init__.py
index 4ffd769..878980d 100644
--- a/afsl/acquisition_functions/__init__.py
+++ b/afsl/acquisition_functions/__init__.py
@@ -1,60 +1,52 @@
 from abc import ABC, abstractmethod
+import math
 from typing import Generic, TypeVar
 import torch
-from afsl.embeddings import M, Embedding
-from afsl.types import Target
-from afsl.utils import mini_batch_wrapper, mini_batch_wrapper_non_cat
+from afsl.embeddings import M
+from afsl.utils import (
+    DEFAULT_MINI_BATCH_SIZE,
+    mini_batch_wrapper,
+    mini_batch_wrapper_non_cat,
+)


-class AcquisitionFunction(ABC):
+class AcquisitionFunction(ABC, Generic[M]):
     mini_batch_size: int

-    def __init__(self, mini_batch_size: int = 100):
+    def __init__(self, mini_batch_size=DEFAULT_MINI_BATCH_SIZE):
         self.mini_batch_size = mini_batch_size

     @abstractmethod
     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         pass


-class BatchAcquisitionFunction(AcquisitionFunction):
+class BatchAcquisitionFunction(AcquisitionFunction[M]):
     @abstractmethod
     def compute(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
         pass

     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         values = mini_batch_wrapper(
             fn=lambda batch: self.compute(
-                embedding=embedding,
                 model=model,
                 data=batch,
-                target=target,
-                Sigma=Sigma,
             ),
             data=data,
             batch_size=self.mini_batch_size,
@@ -66,15 +58,12 @@ def select(
 State = TypeVar("State")


-class SequentialAcquisitionFunction(AcquisitionFunction, Generic[State]):
+class SequentialAcquisitionFunction(AcquisitionFunction[M], Generic[M, State]):
     @abstractmethod
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> State:
         pass

@@ -89,20 +78,14 @@ def step(self, state: State, i: int) -> State:
     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         states = mini_batch_wrapper_non_cat(
             fn=lambda batch: self.initialize(
-                embedding=embedding,
                 model=model,
                 data=batch,
-                target=target,
-                Sigma=Sigma,
             ),
             data=data,
             batch_size=self.mini_batch_size,
@@ -116,7 +99,38 @@ def select(
         indices = []
         for _ in range(batch_size):
             values = torch.cat([self.compute(state) for state in states], dim=0)
-            i = int(torch.argmax(values).item())
+            i = self.selector(values)
             indices.append(i)
             states = [self.step(state, i) for state in states]
         return torch.tensor(indices)
+
+    @staticmethod
+    def selector(values: torch.Tensor) -> int:
+        return int(torch.argmax(values).item())
+
+
+class TargetedAcquisitionFunction(ABC):
+    target: torch.Tensor
+    r"""Tensor of prediction targets (shape $m \times d$)."""
+
+    def __init__(
+        self,
+        target: torch.Tensor,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+    ):
+        assert target.size(0) > 0, "Target must be non-empty"
+        assert (
+            subsampled_target_frac > 0 and subsampled_target_frac <= 1
+        ), "Fraction of target must be in (0, 1]"
+        assert (
+            max_target_size is None or max_target_size > 0
+        ), "Max target size must be positive"
+
+        m = target.size(0)
+        max_target_size = max_target_size if max_target_size is not None else m
+        self.target = target[
+            torch.randperm(m)[
+                : min(math.ceil(subsampled_target_frac * m), max_target_size)
+            ]
+        ]
diff --git a/afsl/acquisition_functions/bace.py b/afsl/acquisition_functions/bace.py
index 053432e..8da9a51 100644
--- a/afsl/acquisition_functions/bace.py
+++ b/afsl/acquisition_functions/bace.py
@@ -1,9 +1,13 @@
 from typing import NamedTuple
 import torch
-from afsl.acquisition_functions import SequentialAcquisitionFunction
+from afsl.acquisition_functions import (
+    SequentialAcquisitionFunction,
+    TargetedAcquisitionFunction,
+)
 from afsl.embeddings import M, Embedding
+from afsl.embeddings.latent import LatentEmbedding
 from afsl.gaussian import GaussianCovarianceMatrix
-from afsl.types import Target
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE


 class BaCEState(NamedTuple):
@@ -11,29 +15,76 @@ class BaCEState(NamedTuple):
     n: int


-class BaCE(SequentialAcquisitionFunction):
+class BaCE(SequentialAcquisitionFunction[M, BaCEState]):
+    embedding: Embedding[M]
+    Sigma: torch.Tensor | None
     noise_std: float

-    def __init__(self, noise_std=1.0):
-        super().__init__()
+    def __init__(
+        self,
+        embedding: Embedding[M] = LatentEmbedding(),
+        Sigma: torch.Tensor | None = None,
+        noise_std=1.0,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.embedding = embedding
+        self.Sigma = Sigma
         self.noise_std = noise_std

     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> BaCEState:
-        assert target is not None, "Target must be non-empty"
+        n = data.size(0)
+        data_embeddings = self.embedding.embed(model, data)
+        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
+            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=self.Sigma
+        )
+        return BaCEState(covariance_matrix=covariance_matrix, n=n)
+
+    def step(self, state: BaCEState, i: int) -> BaCEState:
+        posterior_covariance_matrix = state.covariance_matrix.condition_on(i)
+        return BaCEState(covariance_matrix=posterior_covariance_matrix, n=state.n)
+
+class TargetedBaCE(TargetedAcquisitionFunction, BaCE[M]):
+    def __init__(
+        self,
+        target: torch.Tensor,
+        embedding: Embedding[M] = LatentEmbedding(),
+        Sigma: torch.Tensor | None = None,
+        noise_std=1.0,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        BaCE.__init__(
+            self,
+            embedding=embedding,
+            Sigma=Sigma,
+            noise_std=noise_std,
+            mini_batch_size=mini_batch_size,
+        )
+        TargetedAcquisitionFunction.__init__(
+            self,
+            target=target,
+            subsampled_target_frac=subsampled_target_frac,
+            max_target_size=max_target_size,
+        )
+
+    def initialize(
+        self,
+        model: M,
+        data: torch.Tensor,
+    ) -> BaCEState:
         n = data.size(0)
-        data_embeddings = embedding.embed(model, data)
-        target_embeddings = embedding.embed(model, target)
+        data_embeddings = self.embedding.embed(model, data)
+        target_embeddings = self.embedding.embed(model, self.target)
         joint_embeddings = torch.cat((data_embeddings, target_embeddings))
         covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
-            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=Sigma
+            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=self.Sigma
         )
         return BaCEState(covariance_matrix=covariance_matrix, n=n)
diff --git a/afsl/acquisition_functions/badge.py b/afsl/acquisition_functions/badge.py
index 0eb47db..c0a75e6 100644
--- a/afsl/acquisition_functions/badge.py
+++ b/afsl/acquisition_functions/badge.py
@@ -2,8 +2,7 @@
 import torch
 from afsl.acquisition_functions import SequentialAcquisitionFunction
 from afsl.embeddings import M, Embedding
-from afsl.types import Target
-from afsl.utils import mini_batch_wrapper_non_cat
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE


 class BADGEState(NamedTuple):
@@ -19,16 +18,21 @@ def compute_distances(embeddings, centroids):
     return min_distances


-class BADGE(SequentialAcquisitionFunction[BADGEState]):
+class BADGE(SequentialAcquisitionFunction[M, BADGEState]):
+    embedding: Embedding[M]
+
+    def __init__(
+        self, embedding: Embedding[M], mini_batch_size=DEFAULT_MINI_BATCH_SIZE
+    ):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.embedding = embedding
+
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> BADGEState:
-        embeddings = embedding.embed(model, data)
+        embeddings = self.embedding.embed(model, data)
         # Choose the first centroid randomly
         centroid_indices = [
             torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
@@ -49,34 +53,6 @@ def compute(self, state: BADGEState) -> torch.Tensor:
         probabilities = sqd_distances / sqd_distances.sum()
         return probabilities

-    def select(
-        self,
-        batch_size: int,
-        embedding: Embedding[M],
-        model: M,
-        data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
-        force_nonsequential=False,
-    ) -> torch.Tensor:
-        assert not force_nonsequential, "Non-sequential selection is not supported"
-
-        states = mini_batch_wrapper_non_cat(
-            fn=lambda batch: self.initialize(
-                embedding=embedding,
-                model=model,
-                data=batch,
-                target=target,
-                Sigma=Sigma,
-            ),
-            data=data,
-            batch_size=self.mini_batch_size,
-        )
-
-        indices = []
-        for _ in range(batch_size):
-            probabilities = torch.cat([self.compute(state) for state in states], dim=0)
-            i = int(torch.multinomial(probabilities, num_samples=1).item())
-            indices.append(i)
-            states = [self.step(state, i) for state in states]
-        return torch.tensor(indices)
+    @staticmethod
+    def selector(probabilities: torch.Tensor) -> int:
+        return int(torch.multinomial(probabilities, num_samples=1).item())
diff --git a/afsl/acquisition_functions/cosine_similarity.py b/afsl/acquisition_functions/cosine_similarity.py
index d57be43..2841b36 100644
--- a/afsl/acquisition_functions/cosine_similarity.py
+++ b/afsl/acquisition_functions/cosine_similarity.py
@@ -1,28 +1,39 @@
 import torch
 import torch.nn.functional as F
-from afsl.acquisition_functions import BatchAcquisitionFunction
-from afsl.embeddings import Embedding
+from afsl.acquisition_functions import (
+    BatchAcquisitionFunction,
+    TargetedAcquisitionFunction,
+)
 from afsl.model import LatentModel
-from afsl.types import Target
-from afsl.utils import get_device
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE, get_device


-class CosineSimilarity(BatchAcquisitionFunction):
+class CosineSimilarity(TargetedAcquisitionFunction, BatchAcquisitionFunction):
+    def __init__(
+        self,
+        target: torch.Tensor,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        BatchAcquisitionFunction.__init__(self, mini_batch_size=mini_batch_size)
+        TargetedAcquisitionFunction.__init__(
+            self,
+            target=target,
+            subsampled_target_frac=subsampled_target_frac,
+            max_target_size=max_target_size,
+        )
+
     def compute(
         self,
-        embedding: Embedding,
         model: LatentModel,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert target is not None, "Target must be non-empty"
-
         model.eval()
         device = get_device(model)

         with torch.no_grad():
             data_latent = model.latent(data.to(device))
-            target_latent = model.latent(target.to(device))
+            target_latent = model.latent(self.target.to(device))

             data_latent_normalized = F.normalize(data_latent, p=2, dim=1)
             target_latent_normalized = F.normalize(target_latent, p=2, dim=1)
diff --git a/afsl/acquisition_functions/ctl.py b/afsl/acquisition_functions/ctl.py
index c4115f6..d6640ce 100644
--- a/afsl/acquisition_functions/ctl.py
+++ b/afsl/acquisition_functions/ctl.py
@@ -1,9 +1,8 @@
 import torch
-import wandb
-from afsl.acquisition_functions.bace import BaCE, BaCEState
+from afsl.acquisition_functions.bace import TargetedBaCE, BaCEState


-class CTL(BaCE):
+class CTL(TargetedBaCE):
     def compute(self, state: BaCEState) -> torch.Tensor:
         ind_a = torch.arange(state.n)
         ind_b = torch.arange(state.n, state.covariance_matrix.dim)
diff --git a/afsl/acquisition_functions/greedy_max_det.py b/afsl/acquisition_functions/greedy_max_det.py
index eec3ceb..2a089df 100644
--- a/afsl/acquisition_functions/greedy_max_det.py
+++ b/afsl/acquisition_functions/greedy_max_det.py
@@ -1,12 +1,10 @@
 import torch
 import wandb
 from afsl.acquisition_functions.bace import BaCE, BaCEState
-from afsl.embeddings import M, Embedding
-from afsl.gaussian import GaussianCovarianceMatrix
-from afsl.types import Target
+from afsl.embeddings import M


-class GreedyMaxDet(BaCE):
+class GreedyMaxDet(BaCE[M]):
     def compute(self, state: BaCEState) -> torch.Tensor:
         variances = torch.diag(state.covariance_matrix[:, :])
         wandb.log(
@@ -16,18 +14,3 @@ def compute(self, state: BaCEState) -> torch.Tensor:
             }
         )
         return variances
-
-    def initialize(
-        self,
-        embedding: Embedding[M],
-        model: M,
-        data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
-    ) -> BaCEState:
-        n = data.size(0)
-        data_embeddings = embedding.embed(model, data)
-        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
-            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=Sigma
-        )
-        return BaCEState(covariance_matrix=covariance_matrix, n=n)
diff --git a/afsl/acquisition_functions/greedy_max_dist.py b/afsl/acquisition_functions/greedy_max_dist.py
index a6c3b3c..f745457 100644
--- a/afsl/acquisition_functions/greedy_max_dist.py
+++ b/afsl/acquisition_functions/greedy_max_dist.py
@@ -3,7 +3,7 @@
 from afsl.acquisition_functions import SequentialAcquisitionFunction
 from afsl.acquisition_functions.badge import compute_distances
 from afsl.embeddings import M, Embedding
-from afsl.types import Target
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE


 class GreedyMaxDistState(NamedTuple):
@@ -11,16 +11,21 @@ class GreedyMaxDistState(NamedTuple):
     centroid_indices: List[torch.Tensor]


-class GreedyMaxDist(SequentialAcquisitionFunction[GreedyMaxDistState]):
+class GreedyMaxDist(SequentialAcquisitionFunction[M, GreedyMaxDistState]):
+    embedding: Embedding[M]
+
+    def __init__(
+        self, embedding: Embedding[M], mini_batch_size=DEFAULT_MINI_BATCH_SIZE
+    ):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.embedding = embedding
+
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> GreedyMaxDistState:
-        embeddings = embedding.embed(model, data)
+        embeddings = self.embedding.embed(model, data)
         # Choose the first centroid randomly
         centroid_indices = [
             torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
diff --git a/afsl/acquisition_functions/information_density.py b/afsl/acquisition_functions/information_density.py
index 5dd2ac5..2ca4291 100644
--- a/afsl/acquisition_functions/information_density.py
+++ b/afsl/acquisition_functions/information_density.py
@@ -2,24 +2,26 @@
 from afsl.acquisition_functions import BatchAcquisitionFunction
 from afsl.acquisition_functions.cosine_similarity import CosineSimilarity
 from afsl.acquisition_functions.max_entropy import MaxEntropy
-from afsl.embeddings import Embedding
 from afsl.model import LatentModel
-from afsl.types import Target
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE


 class InformationDensity(BatchAcquisitionFunction):
+    cosine_similarity: CosineSimilarity
+    max_entropy: MaxEntropy
+
+    def __init__(self, target: torch.Tensor, mini_batch_size=DEFAULT_MINI_BATCH_SIZE):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.cosine_similarity = CosineSimilarity(
+            target=target, mini_batch_size=mini_batch_size
+        )
+        self.max_entropy = MaxEntropy(mini_batch_size=mini_batch_size)
+
     def compute(
         self,
-        embedding: Embedding,
         model: LatentModel,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        entropy = MaxEntropy(mini_batch_size=self.mini_batch_size).compute(
-            embedding, model, data, target, Sigma
-        )
-        cosine_similarity = CosineSimilarity(
-            mini_batch_size=self.mini_batch_size
-        ).compute(embedding, model, data, target, Sigma)
+        entropy = self.max_entropy.compute(model, data)
+        cosine_similarity = self.cosine_similarity.compute(model, data)
         return entropy * cosine_similarity
diff --git a/afsl/acquisition_functions/itl.py b/afsl/acquisition_functions/itl.py
index 521bca0..ffffc86 100644
--- a/afsl/acquisition_functions/itl.py
+++ b/afsl/acquisition_functions/itl.py
@@ -1,9 +1,9 @@
 import torch
 import wandb
-from afsl.acquisition_functions.bace import BaCE, BaCEState
+from afsl.acquisition_functions.bace import TargetedBaCE, BaCEState


-class ITL(BaCE):
+class ITL(TargetedBaCE):
     def compute(self, state: BaCEState) -> torch.Tensor:
         variances = torch.diag(state.covariance_matrix[: state.n, : state.n])
         conditional_covariance_matrix = state.covariance_matrix.condition_on(
diff --git a/afsl/acquisition_functions/max_entropy.py b/afsl/acquisition_functions/max_entropy.py
index 0e5ccac..dcdae00 100644
--- a/afsl/acquisition_functions/max_entropy.py
+++ b/afsl/acquisition_functions/max_entropy.py
@@ -1,19 +1,14 @@
 import torch
 from afsl.acquisition_functions import BatchAcquisitionFunction
-from afsl.embeddings import Embedding
 from afsl.model import Model
-from afsl.types import Target
 from afsl.utils import get_device


 class MaxEntropy(BatchAcquisitionFunction):
     def compute(
         self,
-        embedding: Embedding,
         model: Model,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
         model.eval()
         with torch.no_grad():
diff --git a/afsl/acquisition_functions/max_margin.py b/afsl/acquisition_functions/max_margin.py
index bf03130..1f92d13 100644
--- a/afsl/acquisition_functions/max_margin.py
+++ b/afsl/acquisition_functions/max_margin.py
@@ -1,19 +1,14 @@
 import torch
 from afsl.acquisition_functions import BatchAcquisitionFunction
-from afsl.embeddings import Embedding
 from afsl.model import Model
-from afsl.types import Target
 from afsl.utils import get_device


 class MaxMargin(BatchAcquisitionFunction):
     def compute(
         self,
-        embedding: Embedding,
         model: Model,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
         model.eval()
         with torch.no_grad():
diff --git a/afsl/acquisition_functions/vtl.py b/afsl/acquisition_functions/vtl.py
index 1601541..49098bd 100644
--- a/afsl/acquisition_functions/vtl.py
+++ b/afsl/acquisition_functions/vtl.py
@@ -1,9 +1,9 @@
 import torch
 import wandb
-from afsl.acquisition_functions.bace import BaCE, BaCEState
+from afsl.acquisition_functions.bace import TargetedBaCE, BaCEState


-class VTL(BaCE):
+class VTL(TargetedBaCE):
     def compute(self, state: BaCEState) -> torch.Tensor:
         noise_var = self.noise_std**2
diff --git a/afsl/active_data_loader.py b/afsl/active_data_loader.py
index 7d4fcf1..5aaed10 100644
--- a/afsl/active_data_loader.py
+++ b/afsl/active_data_loader.py
@@ -1,70 +1,77 @@
-import math
 from typing import Generic
 import torch
 from afsl.acquisition_functions import AcquisitionFunction
 from afsl.acquisition_functions.itl import ITL
-from afsl.embeddings import M, Embedding
+from afsl.embeddings import M
 from afsl.embeddings.latent import LatentEmbedding
-from afsl.types import Target


 class ActiveDataLoader(Generic[M]):
+    r"""
+    `ActiveDataLoader` can be used as a drop-in replacement for random data selection:
+
+    ```python
+    data_loader = ActiveDataLoader.initialize(data, target, batch_size=64)
+    batch = data[data_loader.next(model)]
+    ```
+
+    where
+    - `model` is a PyTorch `nn.Module`,
+    - `data` is a tensor of inputs (shape $n \times d$), and
+    - `target` is a tensor of prediction targets (shape $m \times d$) or `None`.
+    """
+
     data: torch.Tensor
-    target: Target
+    r"""Tensor of inputs (shape $n \times d$) to be selected from."""
+
     batch_size: int
+    r"""Size of the batch to be selected."""
+
     acquisition_function: AcquisitionFunction
-    embedding: Embedding[M]
+    r"""Acquisition function to be used for data selection."""
+
     subsampled_target_frac: float
     max_target_size: int | None

     def __init__(
         self,
         data: torch.Tensor,
-        target: Target,
         batch_size: int,
-        acquisition_function: AcquisitionFunction = ITL(),
-        embedding: Embedding[M] = LatentEmbedding(),
-        subsampled_target_frac: float = 0.5,
-        max_target_size: int | None = None,
+        acquisition_function: AcquisitionFunction,
     ):
         assert data.size(0) > 0, "Data must be non-empty"
         assert batch_size > 0, "Batch size must be positive"
-        assert (
-            subsampled_target_frac > 0 and subsampled_target_frac <= 1
-        ), "Fraction of target must be in (0, 1]"
-        assert (
-            max_target_size is None or max_target_size > 0
-        ), "Max target size must be positive"

         self.data = data
-        self.target = target
         self.batch_size = batch_size
         self.acquisition_function = acquisition_function
-        self.embedding = embedding
-        self.subsampled_target_frac = subsampled_target_frac
-        self.max_target_size = max_target_size

-    def next(self, model: M, Sigma: torch.Tensor | None = None) -> torch.Tensor:
-        target = self._subsample_target()
-        return self.acquisition_function.select(
-            batch_size=self.batch_size,
-            embedding=self.embedding,
-            model=model,
-            data=self.data,
+    @classmethod
+    def initialize(
+        cls,
+        data: torch.Tensor,
+        target: torch.Tensor,
+        batch_size: int,
+        Sigma: torch.Tensor | None = None,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+    ):
+        acquisition_function = ITL(
             target=target,
+            embedding=LatentEmbedding(),
             Sigma=Sigma,
+            subsampled_target_frac=subsampled_target_frac,
+            max_target_size=max_target_size,
+        )
+        return cls(
+            data=data,
+            batch_size=batch_size,
+            acquisition_function=acquisition_function,
         )

-    def _subsample_target(self) -> Target:
-        if self.target is None:
-            return None
-
-        m = self.target.size(0)
-        max_target_size = (
-            self.max_target_size if self.max_target_size is not None else m
+    def next(self, model: M) -> torch.Tensor:
+        return self.acquisition_function.select(
+            batch_size=self.batch_size,
+            model=model,
+            data=self.data,
         )
-        return self.target[
-            torch.randperm(m)[
-                : min(math.ceil(self.subsampled_target_frac * m), max_target_size)
-            ]
-        ]
diff --git a/afsl/types.py b/afsl/types.py
deleted file mode 100644
index 88337c3..0000000
--- a/afsl/types.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import torch
-
-Target = torch.Tensor | None
diff --git a/afsl/utils.py b/afsl/utils.py
index 8331988..17ccb02 100644
--- a/afsl/utils.py
+++ b/afsl/utils.py
@@ -1,6 +1,8 @@
 import torch
 from afsl.model import Model

+DEFAULT_MINI_BATCH_SIZE = 100
+

 def get_device(model: Model):
     return next(model.parameters()).device
diff --git a/docs/make.py b/docs/make.py
index d1c5483..d630606 100755
--- a/docs/make.py
+++ b/docs/make.py
@@ -49,8 +49,9 @@
 # Render main docs
 pdoc.render.configure(
     edit_url_map={
-        "afsl": "https://github.com/jonhue/afsl/",
-    }
+        "afsl": "https://github.com/jonhue/afsl/docs",
+    },
+    math=True,
 )
 pdoc.pdoc(
     here / ".." / "afsl",