add German RoBERTa model (GottBERT) (facebookresearch#2992)
Summary:

# Before submitting

- There is no related issue for this pull request.
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- We did not see any necessity for tests.

## What does this PR do?

Add German RoBERTa model (GottBERT)

Pull Request resolved: facebookresearch#2992

Reviewed By: alexeib

Differential Revision: D25494927

Pulled By: myleott

fbshipit-source-id: b6790124d7c3c8dc387c141706cd8a527cc950ab
1 parent 032a404 · commit f3d5045

Showing 7 changed files with 124 additions and 2 deletions.
@@ -0,0 +1,64 @@
# GottBERT: a pure German language model

## Introduction

[GottBERT](http://arxiv.org/abs/2012.02110) is a RoBERTa-based language model pretrained on 145GB of German text.

## Example usage

### fairseq

##### Load GottBERT from torch.hub (PyTorch >= 1.1):
```python
import torch
gottbert = torch.hub.load('pytorch/fairseq', 'gottbert-base')
gottbert.eval()  # disable dropout (or leave in train mode to finetune)
```
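Once loaded, the hub interface exposes the standard RoBERTa helpers. A quick round-trip sketch (the sentence is illustrative, not from the original README):
```python
# Encode a sentence into BPE token ids and decode it back.
tokens = gottbert.encode('Hallo Welt !')  # tensor of token ids, incl. <s>/</s>
gottbert.decode(tokens)                   # 'Hallo Welt !'
```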

##### Load GottBERT (for PyTorch 1.0 or custom models):
```bash
# Download the pretrained model archive
wget https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz
tar -xzvf gottbert-base.tar.gz
```

```python
# Load the model in fairseq
from fairseq.models.roberta import GottbertModel
gottbert = GottbertModel.from_pretrained('/path/to/gottbert')
gottbert.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Filling masks:
```python
masked_line = 'Gott ist <mask> ! :)'
gottbert.fill_mask(masked_line, topk=3)
# [('Gott ist gut ! :)', 0.3642110526561737, ' gut'),
#  ('Gott ist überall ! :)', 0.06009674072265625, ' überall'),
#  ('Gott ist großartig ! :)', 0.0370681993663311, ' großartig')]
```
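The `topk` argument controls how many candidate completions are returned, ranked by probability. For example (an illustrative input, not from the original README):
```python
# Request only the single most probable completion.
gottbert.fill_mask('Berlin ist die <mask> von Deutschland .', topk=1)
```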

##### Extract features from GottBERT
```python
# Extract the last layer's features
line = "Der erste Schluck aus dem Becher der Naturwissenschaft macht atheistisch , aber auf dem Grunde des Bechers wartet Gott !"
tokens = gottbert.encode(line)
last_layer_features = gottbert.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 27, 768])

# Extract all layers' features (layer 0 is the embedding layer)
all_layers = gottbert.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```
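To turn the per-token features into a single sentence vector, one common approach is to average over the token dimension. A minimal sketch; the mean-pooling choice is an assumption for illustration, not something GottBERT or fairseq prescribes:
```python
# Mean-pool the token dimension of the last layer into one 768-d sentence embedding.
# Pooling strategy is an assumption; other choices (e.g. the <s> token's vector,
# last_layer_features[:, 0, :]) are equally valid.
sentence_embedding = last_layer_features.mean(dim=1)
assert sentence_embedding.size() == torch.Size([1, 768])
```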

## Citation

If you use our work, please cite:

```bibtex
@misc{scheible2020gottbert,
  title={GottBERT: a pure German Language Model},
  author={Raphael Scheible and Fabian Thomczyk and Patric Tippmann and Victor Jaravine and Martin Boeker},
  year={2020},
  eprint={2012.02110},
  archivePrefix={arXiv},
  primaryClass={cs.CL}
}
```
@@ -0,0 +1,49 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
GottBERT: a pure German Language Model
"""

from fairseq.models import register_model

from .hub_interface import RobertaHubInterface
from .model import RobertaModel


@register_model('gottbert')
class GottbertModel(RobertaModel):

    @classmethod
    def hub_models(cls):
        # Map shorthand model names to their downloadable archives.
        return {
            'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz',
        }

    @classmethod
    def from_pretrained(cls,
                        model_name_or_path,
                        checkpoint_file='model.pt',
                        data_name_or_path='.',
                        bpe='hf_byte_bpe',
                        bpe_vocab='vocab.json',
                        bpe_merges='merges.txt',
                        bpe_add_prefix_space=False,
                        **kwargs
                        ):
        from fairseq import hub_utils

        # Resolve the name or path (downloading via the archive map if needed),
        # then wrap the loaded model in the standard RoBERTa hub interface.
        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            bpe_vocab=bpe_vocab,
            bpe_merges=bpe_merges,
            bpe_add_prefix_space=bpe_add_prefix_space,
            **kwargs,
        )
        return RobertaHubInterface(x['args'], x['task'], x['models'][0])
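Because `hub_models()` is passed to `hub_utils.from_pretrained` as `archive_map`, the shorthand name resolves to the download URL above. A minimal usage sketch, assuming fairseq is installed and the archive is reachable:
```python
# 'gottbert-base' is looked up in the archive map, then downloaded, cached and loaded.
from fairseq.models.roberta import GottbertModel

gottbert = GottbertModel.from_pretrained('gottbert-base')
gottbert.eval()  # disable dropout for inference
```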