From 8d036c2fe01be5158c3ae5265d32c619131d8783 Mon Sep 17 00:00:00 2001
From: Myle Ott
Date: Sun, 28 Jul 2019 18:40:05 -0700
Subject: [PATCH] Add RoBERTa

Summary:
Pull Request resolved: https://github.com/pytorch/fairseq/pull/916

Differential Revision: D16537774

Pulled By: myleott

fbshipit-source-id: 86bb7b1913a428ee4a21674cc3fc7b39264067ec
---
 README.md                                   |  30 +++--
 examples/roberta/README.md                  |   3 +-
 fairseq/models/__init__.py                  |   4 +-
 .../models/{roberta.py => roberta/model.py} | 109 +-----------------
 4 files changed, 25 insertions(+), 121 deletions(-)
 rename fairseq/models/{roberta.py => roberta/model.py} (73%)

diff --git a/README.md b/README.md
index 42253940d3..755e5c8fd0 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,18 @@
-# Introduction
+# Introduction

 Fairseq(-py) is a sequence modeling toolkit that allows researchers and
 developers to train custom models for translation, summarization, language
-modeling and other text generation tasks. It provides reference implementations
-of various sequence-to-sequence models, including:
+modeling and other text generation tasks.
+
+### What's New:
+
+- July 2019: [RoBERTa models and code release](examples/roberta/README.md)
+- June 2019: [wav2vec models and code release](examples/wav2vec/README.md)
+- April 2019: [fairseq demo paper @ NAACL 2019](https://arxiv.org/abs/1904.01038)
+
+### Features:
+
+Fairseq provides reference implementations of various sequence-to-sequence models, including:

 - **Convolutional Neural Networks (CNN)**
   - [Dauphin et al. (2017): Language Modeling with Gated Convolutional Networks](examples/language_model/conv_lm/README.md)
   - [Gehring et al. (2017): Convolutional Sequence to Sequence Learning](examples/conv_seq2seq/README.md)
@@ -11,18 +20,18 @@ of various sequence-to-sequence models, including:
   - [Fan et al. (2018): Hierarchical Neural Story Generation](examples/stories/README.md)
   - **_New_** [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](examples/wav2vec/README.md)
 - **LightConv and DynamicConv models**
-  - **_New_** [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
+  - [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
 - **Long Short-Term Memory (LSTM) networks**
-  - [Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation](https://arxiv.org/abs/1508.04025)
-  - [Wiseman and Rush (2016): Sequence-to-Sequence Learning as Beam-Search Optimization](https://arxiv.org/abs/1606.02960)
+  - Luong et al. (2015): Effective Approaches to Attention-based Neural Machine Translation
 - **Transformer (self-attention) networks**
-  - [Vaswani et al. (2017): Attention Is All You Need](https://arxiv.org/abs/1706.03762)
+  - Vaswani et al. (2017): Attention Is All You Need
   - [Ott et al. (2018): Scaling Neural Machine Translation](examples/scaling_nmt/README.md)
   - [Edunov et al. (2018): Understanding Back-Translation at Scale](examples/backtranslation/README.md)
-  - **_New_** [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
-  - **_New_** [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
+  - [Baevski and Auli (2018): Adaptive Input Representations for Neural Language Modeling](examples/language_model/transformer_lm/README.md)
+  - [Shen et al. (2019): Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
+  - **_New_** [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)

-Fairseq features:
+**Additionally:**
 - multi-GPU (distributed) training on one machine or across multiple machines
 - fast generation on both CPU and GPU with multiple search algorithms implemented:
   - beam search
@@ -83,6 +92,7 @@ as well as example training and evaluation commands.
 - [Language Modeling](examples/language_model/README.md): convolutional models are available

 We also have more detailed READMEs to reproduce results from specific papers:
+- [Liu et al. (2019): RoBERTa: A Robustly Optimized BERT Pretraining Approach](examples/roberta/README.md)
 - [Schneider et al. (2019): wav2vec: Unsupervised Pre-training for Speech Recognition](examples/wav2vec/README.md)
 - [Shen et al. (2019) Mixture Models for Diverse Machine Translation: Tricks of the Trade](examples/translation_moe/README.md)
 - [Wu et al. (2019): Pay Less Attention with Lightweight and Dynamic Convolutions](examples/pay_less_attention_paper/README.md)
diff --git a/examples/roberta/README.md b/examples/roberta/README.md
index 52550157f3..3e757e9289 100644
--- a/examples/roberta/README.md
+++ b/examples/roberta/README.md
@@ -1,6 +1,6 @@
 # RoBERTa: A Robustly Optimized BERT Pretraining Approach

-*Pre-print coming 7/28*
+https://arxiv.org/abs/1907.11692

 ## Introduction

@@ -144,6 +144,7 @@ A more detailed tutorial is coming soon.
     author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
               Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
               Luke Zettlemoyer and Veselin Stoyanov},
+    journal={arXiv preprint arXiv:1907.11692},
     year = {2019},
 }
 ```
diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py
index 21bc9a450e..ae1d5fffa1 100644
--- a/fairseq/models/__init__.py
+++ b/fairseq/models/__init__.py
@@ -123,8 +123,8 @@ def register_model_arch_fn(fn):

 # automatically import any Python files in the models/ directory
 for file in os.listdir(os.path.dirname(__file__)):
-    if file.endswith('.py') and not file.startswith('_'):
-        model_name = file[:file.find('.py')]
+    if not file.startswith('_'):
+        model_name = file[:file.find('.py')] if file.endswith('.py') else file
         module = importlib.import_module('fairseq.models.' + model_name)

         # extra `model_parser` for sphinx
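The `fairseq/models/__init__.py` hunk above is what lets the new `roberta/` package be discovered: a model may now live either in a single `foo.py` file or in a `foo/` package directory, and both get imported at startup so their `@register_model` decorators run. A minimal standalone sketch of that auto-import pattern, assuming only the Python standard library (the `import_models` helper and the `myproject.models` package name are illustrative, not part of this patch):

```python
import importlib
import os


def import_models(models_dir: str, package: str) -> None:
    """Import every model module or package found in ``models_dir``.

    Mirrors the pattern used in fairseq/models/__init__.py after this patch:
    a plain ``foo.py`` file and a ``foo/`` package are treated the same way.
    """
    for entry in os.listdir(models_dir):
        # Skip private entries such as __init__.py and __pycache__.
        if entry.startswith('_'):
            continue
        # 'transformer.py' -> 'transformer'; a package dir like 'roberta' is kept
        # as-is, so importing it runs roberta/__init__.py and registers the model.
        model_name = entry[:-len('.py')] if entry.endswith('.py') else entry
        importlib.import_module(package + '.' + model_name)


# Illustrative usage: scan the directory containing this file.
# import_models(os.path.dirname(__file__), 'myproject.models')
```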
diff --git a/fairseq/models/roberta.py b/fairseq/models/roberta/model.py
similarity index 73%
rename from fairseq/models/roberta.py
rename to fairseq/models/roberta/model.py
index d8001e6d7c..946c19c899 100644
--- a/fairseq/models/roberta.py
+++ b/fairseq/models/roberta/model.py
@@ -13,7 +13,6 @@
 import torch.nn.functional as F

 from fairseq import utils
-from fairseq.data import encoders
 from fairseq.models import (
     FairseqDecoder,
     FairseqLanguageModel,
@@ -26,113 +25,7 @@
 )
 from fairseq.modules.transformer_sentence_encoder import init_bert_params

-
-class RobertaHubInterface(nn.Module):
-    """A simple PyTorch Hub interface to RoBERTa.
-
-    Load RoBERTa::
-
-        >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
-        >>> roberta.eval()  # disable dropout (or leave in train mode to finetune)
-
-    Apply Byte-Pair Encoding (BPE) to input text::
-
-        >>> tokens = roberta.encode('Hello world!')
-        >>> tokens
-        tensor([ 0, 31414, 232, 328, 2])
-
-    Extract features from RoBERTa::
-
-        >>> last_layer_features = roberta.extract_features(tokens)
-        >>> last_layer_features.size()
-        torch.Size([1, 5, 1024])
-
-        >>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
-        >>> len(all_layers)
-        25
-        >>> torch.all(all_layers[-1] == last_layer_features)
-        tensor(1, dtype=torch.uint8)
-
-    Use RoBERTa for sentence-pair classification tasks::
-
-        >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')  # already finetuned
-        >>> roberta.eval()  # disable dropout for evaluation
-
-        >>> tokens = roberta.encode(
-        ...     'Roberta is a heavily optimized version of BERT.',
-        ...     'Roberta is not very optimized.'
-        ... )
-        >>> roberta.predict('mnli', tokens).argmax()
-        tensor(0)  # contradiction
-
-        >>> tokens = roberta.encode(
-        ...     'Roberta is a heavily optimized version of BERT.',
-        ...     'Roberta is based on BERT.'
-        ... )
-        >>> roberta.predict('mnli', tokens).argmax()
-        tensor(2)  # entailment
-
-    Register a new (randomly initialized) classification head::
-
-        >>> roberta.register_classification_head('new_task', num_classes=3)
-        >>> roberta.predict('new_task', tokens)
-        tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
-
-    Using the GPU::
-
-        >>> roberta.cuda()
-        >>> roberta.predict('new_task', tokens)
-        tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
-    """
-
-    def __init__(self, args, task, model):
-        super().__init__()
-        self.args = args
-        self.task = task
-        self.model = model
-
-        self.bpe = encoders.build_bpe(args)
-
-        # this is useful for determining the device
-        self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
-
-    @property
-    def device(self):
-        return self._float_tensor.device
-
-    def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
-        bpe_sentence = '<s> ' + self.bpe.encode(sentence) + ' '
-        for s in addl_sentences:
-            bpe_sentence += ' ' + self.bpe.encode(s)
-        tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
-        return tokens.long()
-
-    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
-        if tokens.dim() == 1:
-            tokens = tokens.unsqueeze(0)
-        features, extra = self.model(
-            tokens.to(device=self.device),
-            features_only=True,
-            return_all_hiddens=return_all_hiddens,
-        )
-        if return_all_hiddens:
-            # convert from T x B x C -> B x T x C
-            inner_states = extra['inner_states']
-            return [inner_state.transpose(0, 1) for inner_state in inner_states]
-        else:
-            return features  # just the last layer's features
-
-    def register_classification_head(
-        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
-    ):
-        self.model.register_classification_head(
-            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
-        )
-
-    def predict(self, head: str, tokens: torch.LongTensor):
-        features = self.extract_features(tokens)
-        logits = self.model.classification_heads[head](features)
-        return F.log_softmax(logits, dim=-1)
+from .hub_interface import RobertaHubInterface


 @register_model('roberta')
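For reference, a short usage sketch of the relocated `RobertaHubInterface`, pieced together from the docstring shown above. It assumes fairseq is installed and `torch.hub` can download the pretrained weights; the `'new_task'` head name is illustrative:

```python
import torch

# Load pretrained RoBERTa through the hub interface introduced in this patch.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout

# BPE-encode a sentence and extract features from the last layer.
tokens = roberta.encode('Hello world!')      # tensor([0, 31414, 232, 328, 2])
features = roberta.extract_features(tokens)  # shape: (1, 5, 1024)

# All hidden states, transposed by the interface from T x B x C to B x T x C.
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert torch.all(all_layers[-1] == features)

# Attach a freshly initialized classification head and score the input.
roberta.register_classification_head('new_task', num_classes=3)
log_probs = roberta.predict('new_task', tokens)  # log-softmax over 3 classes
print(log_probs.argmax(dim=-1))
```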