Skip to content

Commit

Permalink
feat: updated test
Browse files Browse the repository at this point in the history
  • Loading branch information
Darinochka committed Jan 27, 2025
1 parent cd75d72 commit 5d4a8c9
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions tests/modules/retrieval/test_logreg.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import MagicMock

from autointent.modules.embedding import LogRegEmbedding


Expand All @@ -21,17 +19,15 @@ def test_fit_trains_model():
assert module._label_encoder.classes_.tolist() == [0, 1]


def test_score_evaluates_model():
def test_predict_evaluates_model():
module = LogRegEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")

utterances = ["hello", "goodbye", "hi", "bye", "bye", "hello", "welcome", "hi123", "hiii", "bye-bye", "bye!"]
labels = [0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1]
module.fit(utterances, labels)

mock_context = MagicMock()
mock_context.data_handler.test_utterances.return_value = ["hello", "goodbye"]
mock_context.data_handler.test_labels.return_value = [[1, 0], [0, 1]]

scores = module.score(mock_context, split="test")
probas = module.predict(["hello", "bye"])

assert isinstance(scores, dict)
assert len(probas) == 2
assert probas[0][0] > probas[0][1]
assert probas[1][1] > probas[1][0]

0 comments on commit 5d4a8c9

Please sign in to comment.