Add roberta.decode to hub interface to decode BPE (facebookresearch#931)
Summary:
Fixes facebookresearch#930.
Pull Request resolved: facebookresearch#931

Differential Revision: D16562511

Pulled By: myleott

fbshipit-source-id: c4c07e2f067326b79daa547dcb3db84aeddbd555
Myle Ott authored and facebook-github-bot committed Jul 30, 2019
1 parent 3b2cecd commit d82517e
Showing 2 changed files with 16 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/roberta/README.md
@@ -60,6 +60,8 @@ $ tar -xzvf roberta.large.tar.gz
>>> tokens = roberta.encode('Hello world!')
>>> tokens
tensor([ 0, 31414, 232, 328, 2])
>>> roberta.decode(tokens)
'Hello world!'
```

##### Extract features from RoBERTa:
14 changes: 14 additions & 0 deletions fairseq/models/roberta/hub_interface.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -38,6 +39,19 @@ def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=False)
return tokens.long()

def decode(self, tokens: torch.LongTensor):
assert tokens.dim() == 1
tokens = tokens.numpy()
if tokens[0] == self.task.source_dictionary.bos():
tokens = tokens[1:] # remove <s>
eos_mask = (tokens == self.task.source_dictionary.eos())
doc_mask = eos_mask[1:] & eos_mask[:-1]
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
sentences = [self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences]
if len(sentences) == 1:
return sentences[0]
return sentences

def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
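The new `decode` method detects document boundaries by looking for two adjacent `</s>` tokens and splitting the array after each such pair. A minimal standalone sketch of that splitting logic, assuming illustrative token IDs (0 for `<s>`, 2 for `</s>`, consistent with the README example above; real IDs come from `task.source_dictionary`):

```python
import numpy as np

# Hypothetical IDs standing in for source_dictionary.bos() / .eos().
BOS, EOS = 0, 2

# Two sentences: "10 11" and "20 21", each terminated by </s>.
tokens = np.array([BOS, 10, 11, EOS, EOS, 20, 21, EOS])

if tokens[0] == BOS:
    tokens = tokens[1:]  # strip the leading <s>, as decode() does

eos_mask = tokens == EOS
# True at position i when tokens[i] and tokens[i+1] are both </s>
doc_mask = eos_mask[1:] & eos_mask[:-1]
# Split just after each boundary pair
sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
print([s.tolist() for s in sentences])  # → [[10, 11, 2], [2, 20, 21, 2]]
```

Note that with no boundary pair present, `np.split` returns a single chunk, which is why `decode` returns a plain string for single-sentence input and a list otherwise.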
