refactor initial implementation
jonhue committed Feb 19, 2024
1 parent 91c3e63 commit 433c511
Showing 17 changed files with 234 additions and 188 deletions.
10 changes: 9 additions & 1 deletion afsl/__init__.py
@@ -1,5 +1,13 @@
 """
-Active Few-Shot Learning
+*Active Few-Shot Learning* (`afsl`) is a Python package for intelligent active data selection.
+
+## Why Active Data Selection?
+
+## Getting Started
+
+### Installation
+
+---
 """
 
 from afsl.active_data_loader import ActiveDataLoader
72 changes: 43 additions & 29 deletions afsl/acquisition_functions/__init__.py
@@ -1,60 +1,52 @@
 from abc import ABC, abstractmethod
+import math
 from typing import Generic, TypeVar
 import torch
-from afsl.embeddings import M, Embedding
-from afsl.types import Target
-from afsl.utils import mini_batch_wrapper, mini_batch_wrapper_non_cat
+from afsl.embeddings import M
+from afsl.utils import (
+    DEFAULT_MINI_BATCH_SIZE,
+    mini_batch_wrapper,
+    mini_batch_wrapper_non_cat,
+)
 
 
-class AcquisitionFunction(ABC):
+class AcquisitionFunction(ABC, Generic[M]):
     mini_batch_size: int
 
-    def __init__(self, mini_batch_size: int = 100):
+    def __init__(self, mini_batch_size=DEFAULT_MINI_BATCH_SIZE):
         self.mini_batch_size = mini_batch_size
 
     @abstractmethod
     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         pass
 
 
-class BatchAcquisitionFunction(AcquisitionFunction):
+class BatchAcquisitionFunction(AcquisitionFunction[M]):
     @abstractmethod
     def compute(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
         pass
 
     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         values = mini_batch_wrapper(
             fn=lambda batch: self.compute(
-                embedding=embedding,
                 model=model,
                 data=batch,
-                target=target,
-                Sigma=Sigma,
             ),
             data=data,
             batch_size=self.mini_batch_size,
@@ -66,15 +58,12 @@ def select(
 State = TypeVar("State")
 
 
-class SequentialAcquisitionFunction(AcquisitionFunction, Generic[State]):
+class SequentialAcquisitionFunction(AcquisitionFunction[M], Generic[M, State]):
     @abstractmethod
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> State:
         pass
 
@@ -89,20 +78,14 @@ def step(self, state: State, i: int) -> State:
     def select(
         self,
         batch_size: int,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
         force_nonsequential=False,
     ) -> torch.Tensor:
         states = mini_batch_wrapper_non_cat(
             fn=lambda batch: self.initialize(
-                embedding=embedding,
                 model=model,
                 data=batch,
-                target=target,
-                Sigma=Sigma,
             ),
             data=data,
             batch_size=self.mini_batch_size,
@@ -116,7 +99,38 @@ def select(
         indices = []
         for _ in range(batch_size):
             values = torch.cat([self.compute(state) for state in states], dim=0)
-            i = int(torch.argmax(values).item())
+            i = self.selector(values)
             indices.append(i)
             states = [self.step(state, i) for state in states]
         return torch.tensor(indices)
+
+    @staticmethod
+    def selector(values: torch.Tensor) -> int:
+        return int(torch.argmax(values).item())
+
+
+class TargetedAcquisitionFunction(ABC):
+    target: torch.Tensor
+    r"""Tensor of prediction targets (shape $m \times d$)."""
+
+    def __init__(
+        self,
+        target: torch.Tensor,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+    ):
+        assert target.size(0) > 0, "Target must be non-empty"
+        assert (
+            subsampled_target_frac > 0 and subsampled_target_frac <= 1
+        ), "Fraction of target must be in (0, 1]"
+        assert (
+            max_target_size is None or max_target_size > 0
+        ), "Max target size must be positive"
+
+        m = target.size(0)
+        max_target_size = max_target_size if max_target_size is not None else m
+        self.target = target[
+            torch.randperm(m)[
+                : min(math.ceil(subsampled_target_frac * m), max_target_size)
+            ]
+        ]
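The net effect of this refactor is that configuration (embedding, target, Sigma, mini-batch size) moves from `select`'s argument list into constructor state, so callers pass only the batch size, the model, and the candidate pool. A minimal sketch of the new call surface; the `MaxNorm` scorer, model, and tensor shapes are illustrative assumptions, not part of the package:

import torch
from afsl.acquisition_functions import BatchAcquisitionFunction


class MaxNorm(BatchAcquisitionFunction):
    """Hypothetical acquisition function: scores each candidate by the
    norm of the model's output (illustration only)."""

    def compute(self, model, data: torch.Tensor) -> torch.Tensor:
        model.eval()
        with torch.no_grad():
            return model(data).norm(dim=-1)  # one score per candidate


# After the refactor, `select` needs only batch size, model, and pool:
# indices = MaxNorm(mini_batch_size=256).select(
#     batch_size=10, model=model, data=pool
# )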
75 changes: 63 additions & 12 deletions afsl/acquisition_functions/bace.py
@@ -1,39 +1,90 @@
 from typing import NamedTuple
 import torch
-from afsl.acquisition_functions import SequentialAcquisitionFunction
+from afsl.acquisition_functions import (
+    SequentialAcquisitionFunction,
+    TargetedAcquisitionFunction,
+)
 from afsl.embeddings import M, Embedding
+from afsl.embeddings.latent import LatentEmbedding
 from afsl.gaussian import GaussianCovarianceMatrix
-from afsl.types import Target
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE
 
 
 class BaCEState(NamedTuple):
     covariance_matrix: GaussianCovarianceMatrix
     n: int
 
 
-class BaCE(SequentialAcquisitionFunction):
+class BaCE(SequentialAcquisitionFunction[M, BaCEState]):
+    embedding: Embedding[M]
+    Sigma: torch.Tensor | None
     noise_std: float
 
-    def __init__(self, noise_std=1.0):
-        super().__init__()
+    def __init__(
+        self,
+        embedding: Embedding[M] = LatentEmbedding(),
+        Sigma: torch.Tensor | None = None,
+        noise_std=1.0,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.embedding = embedding
+        self.Sigma = Sigma
         self.noise_std = noise_std
 
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> BaCEState:
-        assert target is not None, "Target must be non-empty"
         n = data.size(0)
+        data_embeddings = self.embedding.embed(model, data)
+        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
+            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=self.Sigma
+        )
+        return BaCEState(covariance_matrix=covariance_matrix, n=n)
+
+    def step(self, state: BaCEState, i: int) -> BaCEState:
+        posterior_covariance_matrix = state.covariance_matrix.condition_on(i)
+        return BaCEState(covariance_matrix=posterior_covariance_matrix, n=state.n)
+
+
+class TargetedBaCE(TargetedAcquisitionFunction, BaCE[M]):
+    def __init__(
+        self,
+        target: torch.Tensor,
+        embedding: Embedding[M] = LatentEmbedding(),
+        Sigma: torch.Tensor | None = None,
+        noise_std=1.0,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        BaCE.__init__(
+            self,
+            embedding=embedding,
+            Sigma=Sigma,
+            noise_std=noise_std,
+            mini_batch_size=mini_batch_size,
+        )
+        TargetedAcquisitionFunction.__init__(
+            self,
+            target=target,
+            subsampled_target_frac=subsampled_target_frac,
+            max_target_size=max_target_size,
+        )
+
+    def initialize(
+        self,
+        model: M,
+        data: torch.Tensor,
+    ) -> BaCEState:
+        n = data.size(0)
-        data_embeddings = embedding.embed(model, data)
-        target_embeddings = embedding.embed(model, target)
+        data_embeddings = self.embedding.embed(model, data)
+        target_embeddings = self.embedding.embed(model, self.target)
         joint_embeddings = torch.cat((data_embeddings, target_embeddings))
         covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
-            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=Sigma
+            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=self.Sigma
         )
         return BaCEState(covariance_matrix=covariance_matrix, n=n)
 
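`TargetedBaCE` combines two parents whose constructors take disjoint arguments, so it calls each `__init__` explicitly instead of relying on a single cooperative `super().__init__()` chain. A toy sketch of that pattern; the classes `A`, `B`, and `C` are illustrative only:

class A:
    def __init__(self, x: int):
        self.x = x


class B:
    def __init__(self, y: str):
        self.y = y


class C(A, B):
    def __init__(self, x: int, y: str):
        # Route each argument to the parent that owns it, as TargetedBaCE
        # does with BaCE.__init__ and TargetedAcquisitionFunction.__init__.
        A.__init__(self, x=x)
        B.__init__(self, y=y)


c = C(x=1, y="target")
assert c.x == 1 and c.y == "target"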
52 changes: 14 additions & 38 deletions afsl/acquisition_functions/badge.py
@@ -2,8 +2,7 @@
 import torch
 from afsl.acquisition_functions import SequentialAcquisitionFunction
 from afsl.embeddings import M, Embedding
-from afsl.types import Target
-from afsl.utils import mini_batch_wrapper_non_cat
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE
 
 
 class BADGEState(NamedTuple):
@@ -19,16 +18,21 @@ def compute_distances(embeddings, centroids):
     return min_distances
 
 
-class BADGE(SequentialAcquisitionFunction[BADGEState]):
+class BADGE(SequentialAcquisitionFunction[M, BADGEState]):
+    embedding: Embedding[M]
+
+    def __init__(
+        self, embedding: Embedding[M], mini_batch_size=DEFAULT_MINI_BATCH_SIZE
+    ):
+        super().__init__(mini_batch_size=mini_batch_size)
+        self.embedding = embedding
+
     def initialize(
         self,
-        embedding: Embedding[M],
         model: M,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> BADGEState:
-        embeddings = embedding.embed(model, data)
+        embeddings = self.embedding.embed(model, data)
         # Choose the first centroid randomly
         centroid_indices = [
             torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
@@ -49,34 +53,6 @@ def compute(self, state: BADGEState) -> torch.Tensor:
         probabilities = sqd_distances / sqd_distances.sum()
         return probabilities
 
-    def select(
-        self,
-        batch_size: int,
-        embedding: Embedding[M],
-        model: M,
-        data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
-        force_nonsequential=False,
-    ) -> torch.Tensor:
-        assert not force_nonsequential, "Non-sequential selection is not supported"
-
-        states = mini_batch_wrapper_non_cat(
-            fn=lambda batch: self.initialize(
-                embedding=embedding,
-                model=model,
-                data=batch,
-                target=target,
-                Sigma=Sigma,
-            ),
-            data=data,
-            batch_size=self.mini_batch_size,
-        )
-
-        indices = []
-        for _ in range(batch_size):
-            probabilities = torch.cat([self.compute(state) for state in states], dim=0)
-            i = int(torch.multinomial(probabilities, num_samples=1).item())
-            indices.append(i)
-            states = [self.step(state, i) for state in states]
-        return torch.tensor(indices)
+    @staticmethod
+    def selector(probabilities: torch.Tensor) -> int:
+        return int(torch.multinomial(probabilities, num_samples=1).item())
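With `selector` factored into the base class, BADGE no longer duplicates the entire selection loop: it only swaps the base class's greedy argmax for sampling proportional to squared distances to the nearest centroid, i.e. the k-means++ seeding rule. A toy sketch contrasting the two rules; the values are made up:

import torch

sqd_distances = torch.tensor([0.1, 0.2, 0.7])
probabilities = sqd_distances / sqd_distances.sum()

# Base-class rule: deterministic greedy choice.
greedy = int(torch.argmax(probabilities).item())  # always 2

# BADGE's override: sample an index with probability proportional
# to its squared distance from the nearest chosen centroid.
sampled = int(torch.multinomial(probabilities, num_samples=1).item())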
33 changes: 22 additions & 11 deletions afsl/acquisition_functions/cosine_similarity.py
@@ -1,28 +1,39 @@
 import torch
 import torch.nn.functional as F
-from afsl.acquisition_functions import BatchAcquisitionFunction
-from afsl.embeddings import Embedding
+from afsl.acquisition_functions import (
+    BatchAcquisitionFunction,
+    TargetedAcquisitionFunction,
+)
 from afsl.model import LatentModel
-from afsl.types import Target
-from afsl.utils import get_device
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE, get_device
 
 
-class CosineSimilarity(BatchAcquisitionFunction):
+class CosineSimilarity(TargetedAcquisitionFunction, BatchAcquisitionFunction):
+    def __init__(
+        self,
+        target: torch.Tensor,
+        subsampled_target_frac: float = 0.5,
+        max_target_size: int | None = None,
+        mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
+    ):
+        BatchAcquisitionFunction.__init__(self, mini_batch_size=mini_batch_size)
+        TargetedAcquisitionFunction.__init__(
+            self,
+            target=target,
+            subsampled_target_frac=subsampled_target_frac,
+            max_target_size=max_target_size,
+        )
+
     def compute(
         self,
-        embedding: Embedding,
         model: LatentModel,
         data: torch.Tensor,
-        target: Target,
-        Sigma: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        assert target is not None, "Target must be non-empty"
-
         model.eval()
         device = get_device(model)
         with torch.no_grad():
             data_latent = model.latent(data.to(device))
-            target_latent = model.latent(target.to(device))
+            target_latent = model.latent(self.target.to(device))
 
         data_latent_normalized = F.normalize(data_latent, p=2, dim=1)
         target_latent_normalized = F.normalize(target_latent, p=2, dim=1)
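The displayed hunk ends just after the normalization step, but the remaining computation follows from it: once both latents are L2-normalized, cosine similarities are plain inner products. A sketch of the scoring under that assumption; random tensors stand in for real latents, and the actual reduction over targets in afsl may differ:

import torch
import torch.nn.functional as F

data_latent = torch.randn(100, 32)   # n candidate latents (illustrative)
target_latent = torch.randn(5, 32)   # m (sub)sampled target latents

data_n = F.normalize(data_latent, p=2, dim=1)
target_n = F.normalize(target_latent, p=2, dim=1)

# (n, m) matrix of cosine similarities; averaging over the target
# dimension yields one relevance score per candidate point.
scores = (data_n @ target_n.T).mean(dim=1)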
[Diffs for the remaining 12 changed files are not shown above.]
