
Commit

update metrics
Samoed committed Jan 31, 2025
1 parent 9faa681 commit 51dcc7e
Showing 7 changed files with 33 additions and 29 deletions.
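
The change throughout: the required config field "metric" is renamed to "target_metric" in every node validator and test, the optimizer's internal attribute decision_metric_name is renamed to target_metric, and each validator gains an optional metrics list for extra metrics to compute alongside the optimization target. A minimal before/after sketch of a node config (field names and values come from the diffs below; the "scoring_accuracy" entry is a hypothetical example):

# Before this commit: a single "metric" field selects the optimization target.
node_config = {
    "node_type": "scoring",
    "metric": "scoring_roc_auc",
    "search_space": [{"module_name": "knn", "k": [5, 10]}],
}

# After this commit: "metric" becomes "target_metric", and an optional
# "metrics" list can name additional metrics to compute and log.
node_config = {
    "node_type": "scoring",
    "target_metric": "scoring_roc_auc",
    "metrics": ["scoring_accuracy"],  # hypothetical extra metric
    "search_space": [{"module_name": "knn", "k": [5, 10]}],
}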
10 changes: 5 additions & 5 deletions autointent/nodes/_optimization/_node_optimizer.py
@@ -35,11 +35,11 @@ def __init__(
         """
         self.node_type = node_type
         self.node_info = NODES_INFO[node_type]
-        self.decision_metric_name = target_metric
+        self.target_metric = target_metric

         self.metrics = metrics if metrics is not None else []
-        if self.decision_metric_name not in self.metrics:
-            self.metrics.append(self.decision_metric_name)
+        if self.target_metric not in self.metrics:
+            self.metrics.append(self.target_metric)

         self.modules_search_spaces = search_space  # TODO search space validation
         self._logger = logging.getLogger(__name__)  # TODO solve duplicate logging messages problem
@@ -73,7 +73,7 @@ def fit(self, context: Context) -> None:

         self._logger.debug("scoring %s module...", module_name)
         metrics_score = module.score(context, "validation", self.metrics)
-        metric_value = metrics_score[self.decision_metric_name]
+        metric_value = metrics_score[self.target_metric]

         context.callback_handler.log_metrics(metrics_score)
         context.callback_handler.end_module()
@@ -91,7 +91,7 @@ def fit(self, context: Context) -> None:
             module_name,
             module_kwargs,
             metric_value,
-            self.decision_metric_name,
+            self.target_metric,
             module.get_assets(),  # retriever name / scores / predictions
             module_dump_dir,
             module=module if not context.is_ram_to_clear() else None,
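
Net effect of the rename in __init__ above: the target metric is always tracked, even when a caller passes a metrics list that omits it. A standalone sketch of that logic (the metric names are illustrative):

# Mirrors the updated __init__ logic: append the target metric to the metrics
# list when absent, so module.score() always computes the value that
# fit() reads out as metric_value.
target_metric = "scoring_roc_auc"
metrics = ["scoring_accuracy"]  # hypothetical caller-supplied list

if target_metric not in metrics:
    metrics.append(target_metric)

print(metrics)  # ['scoring_accuracy', 'scoring_roc_auc']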
12 changes: 8 additions & 4 deletions autointent/nodes/schemes.py
@@ -56,7 +56,8 @@ class DecisionNodeValidator(BaseModel):
     """Search space configuration for the Decision node."""

     node_type: NodeType = NodeType.decision
-    metric: DecisionMetrics
+    target_metric: DecisionMetrics
+    metrics: list[DecisionMetrics] | None = None
     search_space: list[DecisionSearchSpaceType]


@@ -70,7 +71,8 @@ class EmbeddingNodeValidator(BaseModel):
     """Search space configuration for the Embedding node."""

     node_type: NodeType = NodeType.embedding
-    metric: EmbeddingMetrics
+    target_metric: EmbeddingMetrics
+    metrics: list[EmbeddingMetrics] | None = None
     search_space: list[EmbeddingSearchSpaceType]


@@ -84,7 +86,8 @@ class ScoringNodeValidator(BaseModel):
     """Search space configuration for the Scoring node."""

     node_type: NodeType = NodeType.scoring
-    metric: ScoringMetrics
+    target_metric: ScoringMetrics
+    metrics: list[ScoringMetrics] | None = None
     search_space: list[ScoringSearchSpaceType]


@@ -98,7 +101,8 @@ class RegexNodeValidator(BaseModel):
     """Search space configuration for the Regexp node."""

     node_type: NodeType = NodeType.regexp
-    metric: RegexpMetrics
+    target_metric: RegexpMetrics
+    metrics: list[RegexpMetrics] | None = None
     search_space: list[RegexpSearchSpaceType]


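Under the updated validators, a config must supply target_metric and may optionally supply metrics. A minimal sketch mirroring the test configs below; the import path and the "decision_accuracy" metric name are assumptions, not shown in this diff:

from autointent.nodes.schemes import OptimizationConfig  # assumed import path

# Sketch: validate a decision-node search space against the new schema.
config = OptimizationConfig(
    [
        {
            "node_type": "decision",
            "target_metric": "decision_roc_auc",  # renamed from "metric"
            "metrics": ["decision_accuracy"],  # new optional field (name hypothetical)
            "search_space": [{"module_name": "argmax"}],
        }
    ]
)
assert config[0].target_metric == "decision_roc_auc"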
8 changes: 4 additions & 4 deletions tests/configs/test_combined_config.py
@@ -13,7 +13,7 @@ def valid_optimizer_config():
     return [
         {
             "node_type": "scoring",
-            "metric": "scoring_roc_auc",
+            "target_metric": "scoring_roc_auc",
             "search_space": [
                 {
                     "module_name": "dnnc",
@@ -28,7 +28,7 @@ def valid_optimizer_config():
         },
         {
             "node_type": "embedding",
-            "metric": "retrieval_hit_rate",
+            "target_metric": "retrieval_hit_rate",
             "search_space": [
                 {
                     "module_name": "retrieval",
@@ -62,7 +62,7 @@ def test_invalid_optimizer_config_missing_field():
     invalid_config = [
         {
             "node_type": "scoring",
-            # Missing "metric"
+            # Missing "target_metric"
             "search_space": [
                 {"module_name": "dnnc", "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], "k": [1, 3]}
             ],
@@ -78,7 +78,7 @@ def test_invalid_optimizer_config_wrong_type():
     invalid_config = [
         {
             "node_type": "scoring",
-            "metric": "scoring_roc_auc",
+            "target_metric": "scoring_roc_auc",
             "search_space": [
                 {
                     "module_name": "dnnc",
8 changes: 4 additions & 4 deletions tests/configs/test_decision.py
@@ -10,7 +10,7 @@ def valid_decision_config():
     return [
         {
             "node_type": "decision",
-            "metric": "decision_roc_auc",
+            "target_metric": "decision_roc_auc",
             "search_space": [
                 {"module_name": "argmax"},
                 {"module_name": "jinoos", "search_space": [[0.3, 0.5, 0.7]]},
@@ -29,7 +29,7 @@ def test_valid_decision_config(valid_decision_config):
     """Test that a valid decision config passes validation."""
     config = OptimizationConfig(valid_decision_config)
     assert config[0].node_type == "decision"
-    assert config[0].metric == "decision_roc_auc"
+    assert config[0].target_metric == "decision_roc_auc"
     assert isinstance(config[0].search_space, list)
     assert config[0].search_space[0].module_name == "argmax"

@@ -39,7 +39,7 @@ def test_invalid_decision_config_missing_field():
     invalid_config = [
         {
             "node_type": "decision",
-            # Missing "metric"
+            # Missing "target_metric"
             "search_space": [{"module_name": "tunable", "n_trials": [100]}],
         }
     ]
@@ -53,7 +53,7 @@ def test_invalid_decision_config_wrong_type():
     invalid_config = [
         {
             "node_type": "decision",
-            "metric": "decision_roc_auc",
+            "target_metric": "decision_roc_auc",
             "search_space": [
                 {
                     "module_name": "threshold",
8 changes: 4 additions & 4 deletions tests/configs/test_embedding.py
@@ -10,7 +10,7 @@ def valid_embedding_config():
     return [
         {
             "node_type": "embedding",
-            "metric": "retrieval_mrr",
+            "target_metric": "retrieval_mrr",
             "search_space": [
                 {"module_name": "logreg_embedding", "embedder_name": ["sergeyzh/rubert-tiny-turbo"], "cv": [3, 5]},
                 {"module_name": "retrieval", "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"], "k": [5, 10]},
@@ -23,7 +23,7 @@ def test_valid_embedding_config(valid_embedding_config):
     """Test that a valid embedding config passes validation."""
     config = OptimizationConfig(valid_embedding_config)
     assert config[0].node_type == "embedding"
-    assert config[0].metric == "retrieval_mrr"
+    assert config[0].target_metric == "retrieval_mrr"
     assert isinstance(config[0].search_space, list)
     assert config[0].search_space[0].module_name == "logreg_embedding"
     assert "embedder_name" in config[0].search_space[0].model_dump()
@@ -34,7 +34,7 @@ def test_invalid_embedding_config_missing_field():
     invalid_config = [
         {
             "node_type": "embedding",
-            # Missing "metric"
+            # Missing "target_metric"
             "search_space": [
                 {"module_name": "retrieval", "embedder_name": ["sentence-transformers/all-MiniLM-L6-v2"], "k": [5, 10]}
             ],
@@ -50,7 +50,7 @@ def test_invalid_embedding_config_wrong_type():
     invalid_config = [
         {
             "node_type": "embedding",
-            "metric": "retrieval_mrr",
+            "target_metric": "retrieval_mrr",
             "search_space": [
                 {
                     "module_name": "logreg_embedding",
8 changes: 4 additions & 4 deletions tests/configs/test_regex.py
@@ -7,14 +7,14 @@
 @pytest.fixture
 def valid_regexp_config():
     """Fixture for a valid RegExp node configuration."""
-    return [{"node_type": "regexp", "metric": "regexp_partial_accuracy", "search_space": [{"module_name": "regexp"}]}]
+    return [{"node_type": "regexp", "target_metric": "regexp_partial_accuracy", "search_space": [{"module_name": "regexp"}]}]


 def test_valid_regexp_config(valid_regexp_config):
     """Test that a valid RegExp config passes validation."""
     config = OptimizationConfig(valid_regexp_config)
     assert config[0].node_type == "regexp"
-    assert config[0].metric == "regexp_partial_accuracy"
+    assert config[0].target_metric == "regexp_partial_accuracy"
     assert isinstance(config[0].search_space, list)
     assert config[0].search_space[0].module_name == "regexp"

@@ -23,7 +23,7 @@ def test_invalid_regexp_config_missing_field():
     """Test that a missing required field raises ValidationError."""
     invalid_config = {
         "node_type": "regexp",
-        # Missing "metric"
+        # Missing "target_metric"
         "search_space": [{"module_name": "regexp"}],
     }

@@ -35,7 +35,7 @@ def test_invalid_regexp_config_wrong_type():
     """Test that an invalid field type raises ValidationError."""
     invalid_config = {
         "node_type": "regexp",
-        "metric": "regexp_partial_accuracy",
+        "target_metric": "regexp_partial_accuracy",
         "search_space": "should_be_a_list",  # Should be a list of dicts
     }

8 changes: 4 additions & 4 deletions tests/configs/test_scoring.py
@@ -10,7 +10,7 @@ def valid_scoring_config():
     return [
         {
             "node_type": "scoring",
-            "metric": "scoring_roc_auc",
+            "target_metric": "scoring_roc_auc",
             "search_space": [
                 {
                     "module_name": "dnnc",
@@ -61,7 +61,7 @@ def test_valid_scoring_config(valid_scoring_config):
     """Test that a valid scoring config passes validation."""
     config = OptimizationConfig(valid_scoring_config)
     assert config[0].node_type == "scoring"
-    assert config[0].metric == "scoring_roc_auc"
+    assert config[0].target_metric == "scoring_roc_auc"
     assert isinstance(config[0].search_space, list)
     assert config[0].search_space[0].module_name == "dnnc"

@@ -70,7 +70,7 @@ def test_invalid_scoring_config_missing_field():
     """Test that a missing required field raises ValidationError."""
     invalid_config = {
         "node_type": "scoring",
-        # Missing "metric"
+        # Missing "target_metric"
         "search_space": [
             {"module_name": "dnnc", "cross_encoder_name": ["cross-encoder/ms-marco-MiniLM-L-6-v2"], "k": [5, 10]}
         ],
@@ -84,7 +84,7 @@ def test_invalid_scoring_config_wrong_type():
     """Test that an invalid field type raises ValidationError."""
     invalid_config = {
         "node_type": "scoring",
-        "metric": "scoring_roc_auc",
+        "target_metric": "scoring_roc_auc",
         "search_space": [
             {
                 "module_name": "knn",
