From 0575b0792bbd92e36e42f6d118685a8a71a38855 Mon Sep 17 00:00:00 2001 From: hankcs Date: Sat, 11 Jan 2020 17:57:21 -0500 Subject: [PATCH] fix error when exporting albert model --- hanlp/losses/sparse_categorical_crossentropy.py | 5 +++++ tests/__init__.py | 4 ++-- tests/train/zh/train_msra_ner_albert.py | 5 ++--- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/hanlp/losses/sparse_categorical_crossentropy.py b/hanlp/losses/sparse_categorical_crossentropy.py index cac83c9a0..22d980860 100644 --- a/hanlp/losses/sparse_categorical_crossentropy.py +++ b/hanlp/losses/sparse_categorical_crossentropy.py @@ -24,6 +24,11 @@ def __call__(self, y_true, y_pred, sample_weight=None, **kwargs): @hanlp_register class SparseCategoricalCrossentropyOverBatchFirstDim(object): + + def __init__(self) -> None: + super().__init__() + self.__name__ = 'sparse_categorical_crossentropy_over_batch_first_dim' + def __call__(self, y_true, y_pred, sample_weight=None, **kwargs): loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True) if sample_weight is not None: diff --git a/tests/__init__.py b/tests/__init__.py index cec81ef03..4d7294bb0 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,11 +3,11 @@ # Date: 2019-06-13 23:43 import os -from tests.resources import project_root +root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) def cdroot(): """ cd to project root, so models are saved in the root folder """ - os.chdir(project_root) + os.chdir(root) diff --git a/tests/train/zh/train_msra_ner_albert.py b/tests/train/zh/train_msra_ner_albert.py index f2e8cb68b..77abe34a2 100644 --- a/tests/train/zh/train_msra_ner_albert.py +++ b/tests/train/zh/train_msra_ner_albert.py @@ -2,14 +2,13 @@ # Author: hankcs # Date: 2019-12-28 23:15 from hanlp.components.ner import TransformerNamedEntityRecognizer -from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST from tests import cdroot cdroot() recognizer = TransformerNamedEntityRecognizer() -save_dir = 'data/model/ner/ner_albert_large_msra' -recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_large_zh', +save_dir = 'data/model/ner/ner_albert_base_zh_msra' +recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_base_zh', learning_rate=5e-5, metrics='f1') recognizer.load(save_dir)