Skip to content

Commit

Permalink
fix error when exporting albert model
Browse files Browse the repository at this point in the history
  • Loading branch information
hankcs committed Jan 11, 2020
1 parent ae3c63d commit 0575b07
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
5 changes: 5 additions & 0 deletions hanlp/losses/sparse_categorical_crossentropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ def __call__(self, y_true, y_pred, sample_weight=None, **kwargs):

@hanlp_register
class SparseCategoricalCrossentropyOverBatchFirstDim(object):

def __init__(self) -> None:
super().__init__()
self.__name__ = 'sparse_categorical_crossentropy_over_batch_first_dim'

def __call__(self, y_true, y_pred, sample_weight=None, **kwargs):
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
if sample_weight is not None:
Expand Down
4 changes: 2 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# Date: 2019-06-13 23:43
import os

from tests.resources import project_root
root = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))


def cdroot():
"""
cd to project root, so models are saved in the root folder
"""
os.chdir(project_root)
os.chdir(root)
5 changes: 2 additions & 3 deletions tests/train/zh/train_msra_ner_albert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
# Author: hankcs
# Date: 2019-12-28 23:15
from hanlp.components.ner import TransformerNamedEntityRecognizer
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
from tests import cdroot

cdroot()
recognizer = TransformerNamedEntityRecognizer()
save_dir = 'data/model/ner/ner_albert_large_msra'
recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_large_zh',
save_dir = 'data/model/ner/ner_albert_base_zh_msra'
recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_base_zh',
learning_rate=5e-5,
metrics='f1')
recognizer.load(save_dir)
Expand Down

0 comments on commit 0575b07

Please sign in to comment.