Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
voorhs committed Jan 27, 2025
1 parent 1e270c6 commit cf5fdef
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 16 deletions.
8 changes: 0 additions & 8 deletions autointent/modules/abc/_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,3 @@

class EmbeddingModule(Module, ABC):
"""Base class for embedding modules."""

def __init__(self, k: int) -> None:
"""
Initialize embedding module.
:param k: number of closest neighbors to consider during inference
"""
self.k = k
9 changes: 3 additions & 6 deletions autointent/modules/embedding/_logreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,12 @@ class LogregAimedEmbedding(EmbeddingModule):
_classifier: LogisticRegressionCV | MultiOutputClassifier
_label_encoder: LabelEncoder | None
name = "logreg"
supports_multiclass = True
supports_multilabel = True
supports_oos = False

def __init__(
self,
k: int,
embedder_name: str,
cv: int = 3,
embedder_device: str = "cpu",
Expand All @@ -59,7 +61,6 @@ def __init__(
Initialize the LogregAimedEmbedding.
:param cv: the number of folds used in LogisticRegressionCV
:param k: Number of nearest neighbors to retrieve.
:param embedder_name: Name of the embedder used for creating embeddings.
:param embedder_device: Device to run operations on, e.g., "cpu" or "cuda".
:param batch_size: Batch size for embedding generation.
Expand All @@ -73,8 +74,6 @@ def __init__(
self.embedder_use_cache = embedder_use_cache
self.cv = cv

super().__init__(k=k)

@classmethod
def from_context(
cls,
Expand All @@ -88,12 +87,10 @@ def from_context(
:param cv: the number of folds used in LogisticRegressionCV
:param context: The context containing configurations and utilities.
:param k: Number of nearest neighbors to retrieve.
:param embedder_name: Name of the embedder to use.
:return: Initialized LogregAimedEmbedding instance.
"""
return cls(
k=k,
cv=cv,
embedder_name=embedder_name,
embedder_device=context.get_device(),
Expand Down
6 changes: 4 additions & 2 deletions autointent/modules/embedding/_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class RetrievalAimedEmbedding(EmbeddingModule):

_vector_index: VectorIndex
name = "retrieval"
supports_multiclass = True
supports_multilabel = True
supports_oos = False

def __init__(
self,
Expand All @@ -57,14 +60,13 @@ def __init__(
:param max_length: Maximum sequence length for embeddings. None if not set.
:param embedder_use_cache: Flag indicating whether to cache intermediate embeddings.
"""
self.k = k
self.embedder_name = embedder_name
self.embedder_device = embedder_device
self.embedder_batch_size = embedder_batch_size
self.embedder_max_length = embedder_max_length
self.embedder_use_cache = embedder_use_cache

super().__init__(k=k)

@classmethod
def from_context(
cls,
Expand Down

0 comments on commit cf5fdef

Please sign in to comment.