Skip to content

Commit

Permalink
More results testing funcs
Browse files Browse the repository at this point in the history
  • Loading branch information
jlumpe committed Aug 14, 2024
1 parent 90fe8db commit 96b5edd
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 46 deletions.
41 changes: 2 additions & 39 deletions tests/data/testdb_210818/generate-results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pathlib import Path

from gambit.seq import SequenceFile
from gambit.db import reportable_taxon
from gambit.query import QueryParams, QueryResults, query_parse
from gambit.results import ResultsArchiveWriter
from gambit.util.misc import zip_strict
Expand All @@ -22,6 +21,7 @@

sys.path.insert(0, str(ROOTDIR))
from tests.testdb import TestDB, TestQueryGenome
from tests.results import check_results as check_results_base


PARAMS = {
Expand All @@ -36,17 +36,12 @@ def check_results(queries: list[TestQueryGenome], query_files: list[SequenceFile
strict = results.params.classify_strict

for query, query_file, item in zip_strict(queries, query_files, results.items):
warnings = []

clsresult = item.classifier_result
predicted = clsresult.predicted_taxon

assert item.input.file == query_file

# No errors
assert clsresult.success
assert clsresult.error is None

# Check if warnings expected (only if in strict mode)
assert bool(clsresult.warnings) == (strict and query['warnings'])

Expand All @@ -62,44 +57,11 @@ def check_results(queries: list[TestQueryGenome], query_files: list[SequenceFile
assert predicted.name == query['predicted']
assert clsresult.primary_match.genome.description == query['primary']

else:
assert clsresult.primary_match == clsresult.closest_match
assert predicted is clsresult.primary_match.matched_taxon

assert item.report_taxon is reportable_taxon(predicted)

else:
assert predicted is None
assert clsresult.primary_match is None
assert item.report_taxon is None

# Closest matches
assert len(item.closest_genomes) == results.params.report_closest
assert item.closest_genomes[0] == clsresult.closest_match

for i in range(1, results.params.report_closest):
assert item.closest_genomes[i].distance >= item.closest_genomes[i-1].distance

# Next taxon
nt = clsresult.next_taxon
if nt is None:
# Predicted should be most specific possible
assert clsresult.closest_match.matched_taxon == clsresult.closest_match.genome.taxon

else:
assert nt.distance_threshold is not None
assert nt.distance_threshold < clsresult.closest_match.distance

# This should hold true as long as the primary match is the closest match, just warn if
# it fails.
if predicted is not None:
if predicted not in nt.ancestors():
warnings.append(f'Next taxon {nt.name} not a descendant of predicted taxon {predicted.name}')

# Display warnings
for w in warnings:
print(f'[Query "{query["name"]}"]:', w, file=sys.stderr)


def main():
testdb = TestDB(THISDIR)
Expand All @@ -111,6 +73,7 @@ def main():
for label, params in PARAMS.items():
print('Running query:', label)
results = query_parse(db, query_files, params)
check_results_base(results)
check_results(testdb.query_genomes, query_files, results)

with open(f'results/{label}.json', 'wt') as f:
Expand Down
73 changes: 68 additions & 5 deletions tests/results.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,80 @@
"""Funcs for testing exported data."""
"""Helper code for tests related to the QueryResults class or exported result data."""

import csv
import json
from typing import TextIO, Any, Iterable, Optional
from pathlib import Path
from warnings import warn

import numpy as np

from gambit.util.json import to_json
from gambit.query import QueryResults, QueryResultItem
from gambit.query import QueryResults, QueryResultItem, QueryParams
from gambit.classify import GenomeMatch, ClassifierResult
from gambit.util.misc import zip_strict
from gambit.db.models import AnnotatedGenome, Taxon
from gambit.db.models import AnnotatedGenome, Taxon, reportable_taxon


def check_results(results: QueryResults, warnings: bool = True):
    """Verify invariants that should hold for a query results object.

    Asserts that the results carry their query parameters, then runs
    ``check_result_item`` on each entry of ``results.items`` with those
    parameters. The ``warnings`` flag is passed through unchanged.
    """
    query_params = results.params
    assert query_params is not None

    for result_item in results.items:
        check_result_item(result_item, query_params, warnings=warnings)


def check_result_item(item: QueryResultItem, params: QueryParams, warnings: bool = True):
    """Verify invariants that should hold for a successful query result item.

    The checks are made with ``assert`` and cover, in order: the classifier
    result reports success with no error; the predicted taxon, primary match
    and report taxon are mutually consistent; ``item.closest_genomes`` has
    ``params.report_closest`` entries, starts with the closest match and is
    ordered by non-decreasing distance; and the next taxon (when present) has
    a distance threshold below the closest match's distance.

    If ``warnings`` is true, a warning (rather than a failure) is issued when
    the predicted taxon is not among the next taxon's ancestors.
    """
    result = item.classifier_result
    taxon = result.predicted_taxon

    # Classification must have completed without error
    assert result.success
    assert result.error is None

    # Consistency between prediction, primary match and reported taxon
    if taxon is None:
        assert result.primary_match is None
        assert item.report_taxon is None

    else:
        assert result.primary_match is not None

        # Outside of strict mode the prediction comes straight from the closest match
        if not params.classify_strict:
            assert result.primary_match == result.closest_match
            assert taxon is result.primary_match.matched_taxon

        assert item.report_taxon is reportable_taxon(taxon)

    # Reported closest genomes: expected count, led by the closest match
    nearest = item.closest_genomes
    assert len(nearest) == params.report_closest
    assert nearest[0] == result.closest_match

    # ... and in order of non-decreasing distance
    for idx in range(1, params.report_closest):
        assert nearest[idx].distance >= nearest[idx - 1].distance

    next_taxon = result.next_taxon

    if next_taxon is None:
        # Nothing more specific to reach, so the closest genome's own taxon was matched
        best = result.closest_match
        assert best.matched_taxon == best.genome.taxon
        return

    assert next_taxon.distance_threshold is not None
    assert next_taxon.distance_threshold < result.closest_match.distance

    # Expected to hold whenever the primary match is the closest match; since that is
    # not guaranteed, a violation is only reported as a warning.
    if taxon is not None and taxon not in next_taxon.ancestors() and warnings:
        warn(
            f'[Query {item.input.label}]: '
            f'next taxon {next_taxon.name} not a descendant of predicted taxon {taxon.name}'
        )


def compare_genome_matches(match1: Optional[GenomeMatch], match2: Optional[GenomeMatch]):
Expand All @@ -32,7 +95,7 @@ def compare_genome_matches(match1: Optional[GenomeMatch], match2: Optional[Genom
assert np.isclose(match1.distance, match2.distance)


def compare_classifier_results(result1: ClassifierResult, result2: ClassifierResult) -> bool:
def compare_classifier_results(result1: ClassifierResult, result2: ClassifierResult):
"""Assert two ``ClassifierResult`` instances are equal."""
assert result1.success == result2.success
assert result1.predicted_taxon == result2.predicted_taxon
Expand All @@ -43,7 +106,7 @@ def compare_classifier_results(result1: ClassifierResult, result2: ClassifierRes
assert result1.error == result2.error


def compare_result_items(item1: QueryResultItem, item2: QueryResultItem) -> bool:
def compare_result_items(item1: QueryResultItem, item2: QueryResultItem):
"""Assert two ``QueryResultItem`` instances are equal.
Does not compare the value of the ``input`` attributes.
Expand Down
8 changes: 6 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gambit import __version__ as GAMBIT_VERSION

from .testdb import TestDB
from .results import compare_result_items
from .results import compare_result_items, check_results


class TestQueryInput:
Expand All @@ -32,10 +32,14 @@ class TestQuery:

def check_results(self, results: QueryResults, ref_results: QueryResults):

# Check general invariants of QueryResults object
check_results(results, warnings=False) # One of the queries is designed to generate a warning
assert results.gambit_version == GAMBIT_VERSION

# Check matches reference results
assert results.params == ref_results.params
assert results.genomeset == ref_results.genomeset
assert results.signaturesmeta == ref_results.signaturesmeta
assert results.gambit_version == GAMBIT_VERSION

for item, ref_item in zip_strict(results.items, ref_results.items):
compare_result_items(item, ref_item)
Expand Down

0 comments on commit 96b5edd

Please sign in to comment.