Commit

more refactoring
jonhue committed Feb 20, 2024
1 parent da4f868 commit 70d0d19
Showing 21 changed files with 253 additions and 278 deletions.
afsl/acquisition_functions/__init__.py (4 changes: 3 additions & 1 deletion)
@@ -2,13 +2,15 @@
 import math
 from typing import Generic, TypeVar
 import torch
-from afsl.embeddings import M
+from afsl.model import Model
 from afsl.utils import (
     DEFAULT_MINI_BATCH_SIZE,
     mini_batch_wrapper,
     mini_batch_wrapper_non_cat,
 )
 
+M = TypeVar("M", bound=Model)
+
 
 class AcquisitionFunction(ABC, Generic[M]):
     mini_batch_size: int
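
Note: the `M` TypeVar, previously re-exported from `afsl.embeddings`, now lives here, bound to the `Model` interface. A hypothetical subclass illustrating the generic pattern (the `compute` signature below is an assumption for illustration; the actual abstract interface is cut off above):

import torch

from afsl.acquisition_functions import AcquisitionFunction, M


class UniformScores(AcquisitionFunction[M]):
    """Hypothetical acquisition function that scores all candidates equally."""

    def compute(self, model: M, data: torch.Tensor) -> torch.Tensor:
        # `model` is statically guaranteed to implement afsl.model.Model.
        return torch.ones(data.size(0))
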
afsl/acquisition_functions/bace.py (55 changes: 33 additions & 22 deletions)
@@ -4,56 +4,58 @@
     SequentialAcquisitionFunction,
     Targeted,
 )
-from afsl.embeddings import M, Embedding
-from afsl.embeddings.provided import ProvidedEmbedding
 from afsl.gaussian import GaussianCovarianceMatrix
-from afsl.utils import DEFAULT_MINI_BATCH_SIZE
+from afsl.model import ModelWithEmbeddingOrKernel, ModelWithKernel
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE, compute_embedding
 
 
 class BaCEState(NamedTuple):
     covariance_matrix: GaussianCovarianceMatrix
     n: int
 
 
-class BaCE(SequentialAcquisitionFunction[M, BaCEState]):
-    embedding: Embedding[M]
+class BaCE(SequentialAcquisitionFunction[ModelWithEmbeddingOrKernel, BaCEState]):
     Sigma: torch.Tensor | None
     noise_std: float
 
     def __init__(
         self,
-        embedding: Embedding[M] = ProvidedEmbedding(),
         Sigma: torch.Tensor | None = None,
         noise_std=1.0,
         mini_batch_size=DEFAULT_MINI_BATCH_SIZE,
     ):
         super().__init__(mini_batch_size=mini_batch_size)
-        self.embedding = embedding
         self.Sigma = Sigma
         self.noise_std = noise_std
 
     def initialize(
         self,
-        model: M,
+        model: ModelWithEmbeddingOrKernel,
         data: torch.Tensor,
     ) -> BaCEState:
         n = data.size(0)
-        data_embeddings = self.embedding.embed(model, data)
-        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
-            noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=self.Sigma
-        )
+        if isinstance(model, ModelWithKernel):
+            covariance_matrix = GaussianCovarianceMatrix(
+                model.kernel(data, None), noise_std=self.noise_std
+            )
+        else:
+            data_embeddings = compute_embedding(
+                model, data, mini_batch_size=self.mini_batch_size
+            )
+            covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
+                noise_std=self.noise_std, Embeddings=data_embeddings, Sigma=self.Sigma
+            )
         return BaCEState(covariance_matrix=covariance_matrix, n=n)
 
     def step(self, state: BaCEState, i: int) -> BaCEState:
         posterior_covariance_matrix = state.covariance_matrix.condition_on(i)
         return BaCEState(covariance_matrix=posterior_covariance_matrix, n=state.n)
 
 
-class TargetedBaCE(Targeted, BaCE[M]):
+class TargetedBaCE(Targeted, BaCE):
     def __init__(
         self,
         target: torch.Tensor,
-        embedding: Embedding[M] = ProvidedEmbedding(),
         Sigma: torch.Tensor | None = None,
         noise_std=1.0,
         subsampled_target_frac: float = 0.5,
@@ -62,7 +64,6 @@ def __init__(
     ):
         BaCE.__init__(
             self,
-            embedding=embedding,
             Sigma=Sigma,
             noise_std=noise_std,
             mini_batch_size=mini_batch_size,
@@ -76,16 +77,26 @@
 
     def initialize(
         self,
-        model: M,
+        model: ModelWithEmbeddingOrKernel,
         data: torch.Tensor,
     ) -> BaCEState:
         n = data.size(0)
-        data_embeddings = self.embedding.embed(model, data)
-        target_embeddings = self.embedding.embed(model, self.target)
-        joint_embeddings = torch.cat((data_embeddings, target_embeddings))
-        covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
-            noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=self.Sigma
-        )
+        if isinstance(model, ModelWithKernel):
+            covariance_matrix = GaussianCovarianceMatrix(
+                model.kernel(torch.cat((data, self.target)), None),
+                noise_std=self.noise_std,
+            )
+        else:
+            data_embeddings = compute_embedding(
+                model, data=data, mini_batch_size=self.mini_batch_size
+            )
+            target_embeddings = compute_embedding(
+                model, data=self.target, mini_batch_size=self.mini_batch_size
+            )
+            joint_embeddings = torch.cat((data_embeddings, target_embeddings))
+            covariance_matrix = GaussianCovarianceMatrix.from_embeddings(
+                noise_std=self.noise_std, Embeddings=joint_embeddings, Sigma=self.Sigma
+            )
         return BaCEState(covariance_matrix=covariance_matrix, n=n)
 
     def step(self, state: BaCEState, i: int) -> BaCEState:
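
Note: both `initialize` methods now dispatch on the model type: kernel models build the `GaussianCovarianceMatrix` directly from `model.kernel(...)`, while embedding models go through `compute_embedding`. The implementation of `compute_embedding` is not part of this diff; a minimal sketch consistent with the call sites above, assuming `get_device` from `afsl.utils` and batch-wise concatenation:

import torch

from afsl.model import ModelWithEmbedding
from afsl.utils import DEFAULT_MINI_BATCH_SIZE, get_device


def compute_embedding(
    model: ModelWithEmbedding,
    data: torch.Tensor,
    mini_batch_size: int = DEFAULT_MINI_BATCH_SIZE,
) -> torch.Tensor:
    # Embed the data in mini-batches to bound memory usage, then
    # concatenate the per-batch embeddings into a single tensor.
    model.eval()
    device = get_device(model)
    embeddings = []
    with torch.no_grad():
        for batch in data.split(mini_batch_size):
            embeddings.append(model.embed(batch.to(device)))
    return torch.cat(embeddings)
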
afsl/acquisition_functions/badge.py (20 changes: 7 additions & 13 deletions)
@@ -1,8 +1,8 @@
 from typing import List, NamedTuple
 import torch
 from afsl.acquisition_functions import SequentialAcquisitionFunction
-from afsl.embeddings import M, Embedding
-from afsl.utils import DEFAULT_MINI_BATCH_SIZE
+from afsl.model import ModelWithEmbedding
+from afsl.utils import compute_embedding
 
 
 class BADGEState(NamedTuple):
@@ -18,21 +18,15 @@ def compute_distances(embeddings, centroids):
     return min_distances
 
 
-class BADGE(SequentialAcquisitionFunction[M, BADGEState]):
-    embedding: Embedding[M]
-
-    def __init__(
-        self, embedding: Embedding[M], mini_batch_size=DEFAULT_MINI_BATCH_SIZE
-    ):
-        super().__init__(mini_batch_size=mini_batch_size)
-        self.embedding = embedding
-
+class BADGE(SequentialAcquisitionFunction[ModelWithEmbedding, BADGEState]):
     def initialize(
         self,
-        model: M,
+        model: ModelWithEmbedding,
         data: torch.Tensor,
     ) -> BADGEState:
-        embeddings = self.embedding.embed(model, data)
+        embeddings = compute_embedding(
+            model, data, mini_batch_size=self.mini_batch_size
+        )
         # Choose the first centroid randomly
         centroid_indices = [
             torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
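
Note: `initialize` seeds the centroids k-means++-style: the first centroid is drawn uniformly, and `compute_distances` supplies each point's distance to its nearest centroid for the subsequent draws. A self-contained sketch of that seeding scheme (a standalone illustration, not the class's exact control flow, which is spread across `initialize` and later steps):

import torch


def kmeans_pp_seed(embeddings: torch.Tensor, k: int) -> list[int]:
    # First centroid: uniformly at random.
    indices = [int(torch.randint(0, embeddings.size(0), (1,)))]
    for _ in range(k - 1):
        centroids = embeddings[indices]
        # Squared distance of each point to its nearest chosen centroid.
        sq_dists = torch.cdist(embeddings, centroids).square().min(dim=1).values
        # Draw the next centroid with probability proportional to that distance.
        indices.append(int(torch.multinomial(sq_dists / sq_dists.sum(), 1)))
    return indices
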
afsl/acquisition_functions/cosine_similarity.py (10 changes: 7 additions & 3 deletions)
@@ -5,7 +5,7 @@
     Targeted,
 )
 from afsl.model import ModelWithEmbedding
-from afsl.utils import DEFAULT_MINI_BATCH_SIZE, get_device
+from afsl.utils import DEFAULT_MINI_BATCH_SIZE, compute_embedding, get_device
 
 
 class CosineSimilarity(Targeted, BatchAcquisitionFunction):
@@ -32,8 +32,12 @@ def compute(
         model.eval()
         device = get_device(model)
         with torch.no_grad():
-            data_latent = model.embed(data.to(device))
-            target_latent = model.embed(self.target.to(device))
+            data_latent = compute_embedding(
+                model, data=data, mini_batch_size=self.mini_batch_size
+            )
+            target_latent = compute_embedding(
+                model, data=self.target, mini_batch_size=self.mini_batch_size
+            )
 
             data_latent_normalized = F.normalize(data_latent, p=2, dim=1)
             target_latent_normalized = F.normalize(target_latent, p=2, dim=1)
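
Note: once both embeddings are L2-normalized, a single matrix product yields all pairwise cosine similarities. The rest of `compute` is cut off above; as an illustration of the operation (not the file's exact code), a per-candidate score could be the mean similarity to the target set:

import torch
import torch.nn.functional as F


def cosine_scores(data_latent: torch.Tensor, target_latent: torch.Tensor) -> torch.Tensor:
    # Rows are unit vectors, so the matrix product is the cosine similarity.
    data_n = F.normalize(data_latent, p=2, dim=1)
    target_n = F.normalize(target_latent, p=2, dim=1)
    # One score per candidate: average similarity to all target points.
    return (data_n @ target_n.T).mean(dim=1)
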
afsl/acquisition_functions/greedy_max_det.py (3 changes: 1 addition & 2 deletions)
@@ -1,10 +1,9 @@
 import torch
 import wandb
 from afsl.acquisition_functions.bace import BaCE, BaCEState
-from afsl.embeddings import M
 
 
-class GreedyMaxDet(BaCE[M]):
+class GreedyMaxDet(BaCE):
     def compute(self, state: BaCEState) -> torch.Tensor:
         variances = torch.diag(state.covariance_matrix[:, :])
         wandb.log(
afsl/acquisition_functions/greedy_max_dist.py (22 changes: 9 additions & 13 deletions)
@@ -2,30 +2,26 @@
 import torch
 from afsl.acquisition_functions import SequentialAcquisitionFunction
 from afsl.acquisition_functions.badge import compute_distances
-from afsl.embeddings import M, Embedding
-from afsl.utils import DEFAULT_MINI_BATCH_SIZE
+from afsl.model import ModelWithEmbedding
+from afsl.utils import compute_embedding
 
 
 class GreedyMaxDistState(NamedTuple):
     embeddings: torch.Tensor
     centroid_indices: List[torch.Tensor]
 
 
-class GreedyMaxDist(SequentialAcquisitionFunction[M, GreedyMaxDistState]):
-    embedding: Embedding[M]
-
-    def __init__(
-        self, embedding: Embedding[M], mini_batch_size=DEFAULT_MINI_BATCH_SIZE
-    ):
-        super().__init__(mini_batch_size=mini_batch_size)
-        self.embedding = embedding
-
+class GreedyMaxDist(
+    SequentialAcquisitionFunction[ModelWithEmbedding, GreedyMaxDistState]
+):
     def initialize(
         self,
-        model: M,
+        model: ModelWithEmbedding,
         data: torch.Tensor,
     ) -> GreedyMaxDistState:
-        embeddings = self.embedding.embed(model, data)
+        embeddings = compute_embedding(
+            model, data, mini_batch_size=self.mini_batch_size
+        )
         # Choose the first centroid randomly
         centroid_indices = [
             torch.randint(0, embeddings.size(0), (1,)).to(embeddings.device)
afsl/active_data_loader.py (10 changes: 3 additions & 7 deletions)
@@ -1,9 +1,7 @@
 from typing import Generic
 import torch
-from afsl.acquisition_functions import AcquisitionFunction
+from afsl.acquisition_functions import M, AcquisitionFunction
 from afsl.acquisition_functions.itl import ITL
-from afsl.embeddings import M, Embedding
-from afsl.embeddings.provided import ProvidedEmbedding
 
 
 class ActiveDataLoader(Generic[M]):
@@ -27,7 +25,7 @@ class ActiveDataLoader(Generic[M]):
     batch_size: int
     r"""Size of the batch to be selected."""
 
-    acquisition_function: AcquisitionFunction
+    acquisition_function: AcquisitionFunction[M]
     r"""Acquisition function to be used for data selection."""
 
     subsampled_target_frac: float
@@ -37,7 +35,7 @@ def __init__(
         self,
         data: torch.Tensor,
         batch_size: int,
-        acquisition_function: AcquisitionFunction,
+        acquisition_function: AcquisitionFunction[M],
     ):
         assert data.size(0) > 0, "Data must be non-empty"
         assert batch_size > 0, "Batch size must be positive"
@@ -52,14 +50,12 @@ def initialize(
         data: torch.Tensor,
         target: torch.Tensor,
         batch_size: int,
-        embedding: Embedding = ProvidedEmbedding(),
         Sigma: torch.Tensor | None = None,
         subsampled_target_frac: float = 0.5,
         max_target_size: int | None = None,
     ):
         acquisition_function = ITL(
             target=target,
-            embedding=embedding,
             Sigma=Sigma,
             subsampled_target_frac=subsampled_target_frac,
             max_target_size=max_target_size,
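
Note: after this change, `ActiveDataLoader.initialize` no longer takes an `embedding` argument; the model's own embedding or kernel is used at selection time instead. A usage sketch based on the signature above (assuming `initialize` is a classmethod returning a configured loader, as the visible body suggests; shapes are illustrative):

import torch

from afsl.active_data_loader import ActiveDataLoader

data = torch.randn(10_000, 128)  # candidate pool
target = torch.randn(50, 128)    # examples characterizing the prediction task

loader = ActiveDataLoader.initialize(
    data=data,
    target=target,
    batch_size=64,
)
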
afsl/embeddings/__init__.py (18 changes: 0 additions & 18 deletions)
@@ -1,18 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
-import torch
-from afsl.model import Model
-from afsl.utils import DEFAULT_MINI_BATCH_SIZE
-
-M = TypeVar("M", bound=Model)
-
-
-class Embedding(ABC, Generic[M]):
-    mini_batch_size: int
-
-    def __init__(self, mini_batch_size=DEFAULT_MINI_BATCH_SIZE):
-        self.mini_batch_size = mini_batch_size
-
-    @abstractmethod
-    def embed(self, model: M, data: torch.Tensor) -> torch.Tensor:
-        pass
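
Note: the deleted `Embedding` hierarchy is superseded by structural model types in `afsl.model`. Their definitions are not part of this diff; judging by the methods exercised above (`model.embed(...)`, `model.kernel(data, None)`, and the `isinstance` check in bace.py, which requires runtime-checkable protocols), they plausibly look like this sketch:

from typing import Protocol, Union, runtime_checkable

import torch


@runtime_checkable
class ModelWithEmbedding(Protocol):
    def embed(self, data: torch.Tensor) -> torch.Tensor:
        """Map a batch of inputs to latent embeddings."""
        ...


@runtime_checkable
class ModelWithKernel(Protocol):
    def kernel(self, x1: torch.Tensor, x2: torch.Tensor | None) -> torch.Tensor:
        """Evaluate the kernel matrix; x2=None presumably means x2 = x1."""
        ...


ModelWithEmbeddingOrKernel = Union[ModelWithEmbedding, ModelWithKernel]
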
