Commit

BigBird added for NER and classification
Thilina Rajapakse committed May 28, 2021
1 parent 0ed7a75 commit d6b9acc
Showing 4 changed files with 38 additions and 8 deletions.
16 changes: 14 additions & 2 deletions CHANGELOG.md
@@ -4,10 +4,20 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

-## [0.61.5] - 2021-05-18
+## [0.61.6] - 2021-05-28

+### Fixed
+
+- Fixed the onnx predict loop [whr778](https://github.com/whr778)
+
+### Added
+
+- Added NER support for BigBird, Deberta, Deberta-v2, and xlm pretrained models [whr778](https://github.com/whr778)
+- Added BigBird for regular sequence classification (not multilabel) [@manueltonneau](https://github.com/manueltonneau)
+
+## [0.61.5] - 2021-05-18
 ### Added

 - Fixed possible bug when using HF Datasets with Seq2SeqModel
 - Added `repo: simpletransformers` to W&B config

@@ -1504,7 +1514,9 @@ Model checkpoint is now saved for all epochs again.

 - This CHANGELOG file to hopefully serve as an evolving example of a standardized open source project CHANGELOG.

-[0.61.5]: https://github.com/ThilinaRajapakse/simpletransformers/compare/b49bf28...HEAD
+[0.61.6]: https://github.com/ThilinaRajapakse/simpletransformers/compare/281ff31...HEAD
+
+[0.61.5]: https://github.com/ThilinaRajapakse/simpletransformers/compare/b49bf28...281ff31

[0.61.4]: https://github.com/ThilinaRajapakse/simpletransformers/compare/87eeb0e...b49bf28

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name="simpletransformers",
-version="0.61.5",
+version="0.61.6",
author="Thilina Rajapakse",
author_email="[email protected]",
description="An easy-to-use wrapper library for the Transformers library.",
23 changes: 20 additions & 3 deletions tests/test_classification.py
@@ -11,6 +11,7 @@
"model_type, model_name",
[
("bert", "bert-base-uncased"),
+("bigbird", "google/bigbird-roberta-base"),
# ("longformer", "allenai/longformer-base-4096"),
# ("electra", "google/electra-small-discriminator"),
# ("mobilebert", "google/mobilebert-uncased"),
@@ -70,6 +71,7 @@ def test_binary_classification(model_type, model_name):
[
# ("bert", "bert-base-uncased"),
# ("xlnet", "xlnet-base-cased"),
+("bigbird", "google/bigbird-roberta-base"),
# ("xlm", "xlm-mlm-17-1280"),
("roberta", "roberta-base"),
# ("distilbert", "distilbert-base-uncased"),
@@ -175,7 +177,22 @@ def test_multilabel_classification(model_type, model_name):
)


-def test_sliding_window():
+@pytest.mark.parametrize(
+"model_type, model_name",
+[
+# ("bert", "bert-base-uncased"),
+# ("xlnet", "xlnet-base-cased"),
+("bigbird", "google/bigbird-roberta-base"),
+# ("xlm", "xlm-mlm-17-1280"),
+("roberta", "roberta-base"),
+# ("distilbert", "distilbert-base-uncased"),
+# ("albert", "albert-base-v1"),
+# ("camembert", "camembert-base"),
+# ("xlmroberta", "xlm-roberta-base"),
+# ("flaubert", "flaubert-base-cased"),
+],
+)
+def test_sliding_window(model_type, model_name):
# Train and Evaluation data needs to be in a Pandas Dataframe of two columns.
# The first column is the text with type str, and the second column is the
# label with type int.
@@ -193,8 +210,8 @@ def test_sliding_window():

# Create a ClassificationModel
model = ClassificationModel(
-"distilbert",
-"distilbert-base-uncased",
+model_type,
+model_name,
use_cuda=False,
args={
"no_save": True,
5 changes: 3 additions & 2 deletions tests/test_named_entity_recognition.py
@@ -8,10 +8,11 @@
"model_type, model_name",
[
("bert", "bert-base-uncased"),
-# ("longformer", "allenai/longformer-base-4096"),
+("bigbird", "google/bigbird-roberta-base"),
+("longformer", "allenai/longformer-base-4096"),
# ("xlnet", "xlnet-base-cased"),
# ("xlm", "xlm-mlm-17-1280"),
-# ("roberta", "roberta-base"),
+("roberta", "roberta-base"),
# ("distilbert", "distilbert-base-uncased"),
# ("albert", "albert-base-v1"),
# ("camembert", "camembert-base"),
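
Usage sketch (not part of the commit): the classification tests above pass `model_type` and `model_name` straight into `ClassificationModel`, so a BigBird classifier can presumably be built the same way. The following is a minimal sketch under that assumption. The helper names `build_toy_rows`, `make_bigbird_classifier`, and `run_demo` are hypothetical, the column names `text` and `labels` are an assumption about the expected DataFrame layout, and only the constructor arguments shown in the diff (`use_cuda=False`, `"no_save": True`) are taken from the source.

```python
# Sketch only: mirrors the ClassificationModel call visible in the test diff.
# All helper names below are hypothetical and exist only for illustration.


def build_toy_rows():
    """Return [text, label] rows: text is a str, label is an int."""
    return [
        ["Example sentence belonging to class 1", 1],
        ["Example sentence belonging to class 0", 0],
    ]


def make_bigbird_classifier():
    """Build a CPU-only BigBird classifier that does not save checkpoints."""
    # Imported lazily so build_toy_rows() works without the library installed.
    from simpletransformers.classification import ClassificationModel

    return ClassificationModel(
        "bigbird",
        "google/bigbird-roberta-base",
        use_cuda=False,
        args={"no_save": True},
    )


def run_demo():
    """Train on the toy rows; downloads the pretrained model on first use."""
    import pandas as pd

    # Assumed layout: a text column followed by an integer label column.
    train_df = pd.DataFrame(build_toy_rows(), columns=["text", "labels"])
    model = make_bigbird_classifier()
    model.train_model(train_df)
    return model
```

Calling `run_demo()` requires `simpletransformers` and `pandas` to be installed and network access to fetch `google/bigbird-roberta-base`. Per the changelog entry, BigBird is supported for regular sequence classification only, not multilabel.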
