diff --git a/.gitignore b/.gitignore index 8bc1d55..f5936ee 100644 --- a/.gitignore +++ b/.gitignore @@ -101,3 +101,6 @@ venv.bak/ # mypy .mypy_cache/ + +# Models +/predictwhens \ No newline at end of file diff --git a/models/ner.py b/models/ner.py index 36f9c3f..801a9f9 100644 --- a/models/ner.py +++ b/models/ner.py @@ -29,10 +29,10 @@ logger = logging.getLogger(__name__) - -ALL_MODELS = sum( - [list(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)], - []) +# 'pretrained_config_archive_map' is deprecated and the variable is not actually in use +# ALL_MODELS = sum( +# [list(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)], +# []) MODEL_CLASSES = { 'bert': (BertConfig, BertForTokenClassification, BertTokenizer), @@ -407,14 +407,23 @@ def predict(self, tasks, **kwargs): result = [] + # bool for when 'O'-label is not appended in 'result' + skipped = False + for label, group in groupby(zip(preds, starts, scores), key=lambda i: re.sub('^(B-|I-)', '', i[0])): _, group_start, _ = list(group)[0] if len(result) > 0: if group_start == 0: result.pop(-1) - else: + # when 'O' is skipped when appending 'result', 'end' of previous 'result' should not be changed, until a new label is appended + elif not skipped: result[-1]['value']['end'] = group_start - 1 + # remove incorrect predictions, where 'end' is smaller than 'start' + if result[-1]['value']['end'] is not None and result[-1]['value']['end'] < result[-1]['value']['start']: + result.pop(-1) + if label != 'O': + skipped = False result.append({ 'from_name': from_name, 'to_name': to_name, @@ -426,6 +435,9 @@ def predict(self, tasks, **kwargs): 'text': '...' } }) + else: + skipped = True + if result and result[-1]['value']['end'] is None: result[-1]['value']['end'] = len(string) results.append({