diff --git a/examples/roberta/README.md b/examples/roberta/README.md
index e975789f01..21b04c845a 100644
--- a/examples/roberta/README.md
+++ b/examples/roberta/README.md
@@ -76,6 +76,28 @@ assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```
+By default RoBERTa outputs one feature vector per BPE token. You can instead
+realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
+with the `extract_features_aligned_to_words` method. This will compute a
+weighted average of the BPE-level features for each word and expose them in
+spaCy's `Token.vector` attribute:
+```python
+doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
+assert len(doc) == 10
+for tok in doc:
+    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
+# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
+# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
+# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
+# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
+# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
+# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
+# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
+# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
+```
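+
+Since each `Token.vector` here is a regular `torch.Tensor`, the word-level
+features can also be collected into a single matrix. A minimal sketch, reusing
+the `doc` from the example above (`word_features` is just an illustrative name):
+```python
+import torch
+
+word_features = torch.stack([tok.vector for tok in doc])  # one row per spaCy token
+assert word_features.size(0) == len(doc)  # 10 rows; columns = hidden dim (e.g. 1024 for roberta.large)
+```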
+
##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
diff --git a/fairseq/data/encoders/fastbpe.py b/fairseq/data/encoders/fastbpe.py
index 61a8f726ec..ed39b1bca9 100644
--- a/fairseq/data/encoders/fastbpe.py
+++ b/fairseq/data/encoders/fastbpe.py
@@ -25,7 +25,7 @@ def __init__(self, args):
            self.bpe = fastBPE.fastBPE(codes)
            self.bpe_symbol = "@@ "
        except ImportError:
-            raise ImportError('Please install fastbpe at https://github.com/glample/fastBPE')
+            raise ImportError('Please install fastBPE with: pip install fastBPE')

    def encode(self, x: str) -> str:
        return self.bpe.apply([x])[0]
diff --git a/fairseq/models/roberta/alignment_utils.py b/fairseq/models/roberta/alignment_utils.py
new file mode 100644
index 0000000000..85da2c4c01
--- /dev/null
+++ b/fairseq/models/roberta/alignment_utils.py
@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import Counter
+from typing import List
+
+import torch
+
+
+def align_bpe_to_words(roberta, bpe_tokens: torch.LongTensor, other_tokens: List[str]):
+ """
+ Helper to align GPT-2 BPE to other tokenization formats (e.g., spaCy).
+
+ Args:
+ roberta (RobertaHubInterface): RoBERTa instance
+ bpe_tokens (torch.LongTensor): GPT-2 BPE tokens of shape `(T_bpe)`
+ other_tokens (List[str]): other tokens of shape `(T_words)`
+
+ Returns:
+ List[str]: mapping from *other_tokens* to corresponding *bpe_tokens*.
+ """
+    assert bpe_tokens.dim() == 1
+
+    def clean(text):
+        return text.strip()
+
+    # remove whitespaces to simplify alignment
+    bpe_tokens = [roberta.task.source_dictionary.string([x]) for x in bpe_tokens]
+    bpe_tokens = [clean(roberta.bpe.decode(x) if x not in {'<s>', ''} else x) for x in bpe_tokens]
+    other_tokens = [clean(str(o)) for o in other_tokens]
+
+    # strip leading <s>
+    assert bpe_tokens[0] == '<s>'
+    bpe_tokens = bpe_tokens[1:]
+    assert ''.join(bpe_tokens) == ''.join(other_tokens)
+
+    # create alignment from every word to a list of BPE tokens
+    alignment = []
+    bpe_toks = filter(lambda item: item[1] != '', enumerate(bpe_tokens, start=1))
+    j, bpe_tok = next(bpe_toks)
+    for other_tok in other_tokens:
+        bpe_indices = []
+        while True:
+            if other_tok.startswith(bpe_tok):
+                bpe_indices.append(j)
+                other_tok = other_tok[len(bpe_tok):]
+                try:
+                    j, bpe_tok = next(bpe_toks)
+                except StopIteration:
+                    j, bpe_tok = None, None
+            elif bpe_tok.startswith(other_tok):
+                # other_tok spans multiple BPE tokens
+                bpe_indices.append(j)
+                bpe_tok = bpe_tok[len(other_tok):]
+                other_tok = ''
+            else:
+                raise Exception('Cannot align "{}" and "{}"'.format(other_tok, bpe_tok))
+            if other_tok == '':
+                break
+        assert len(bpe_indices) > 0
+        alignment.append(bpe_indices)
+    assert len(alignment) == len(other_tokens)
+
+    return alignment
+
+
+def align_features_to_words(roberta, features, alignment):
+ """
+ Align given features to words.
+
+ Args:
+ roberta (RobertaHubInterface): RoBERTa instance
+ features (torch.Tensor): features to align of shape `(T_bpe x C)`
+ alignment: alignment between BPE tokens and words returned by
+ func:`align_bpe_to_words`.
+ """
+    assert features.dim() == 2
+
+    bpe_counts = Counter(j for bpe_indices in alignment for j in bpe_indices)
+    assert bpe_counts[0] == 0  # <s> shouldn't be aligned
+    denom = features.new([bpe_counts.get(j, 1) for j in range(len(features))])
+    weighted_features = features / denom.unsqueeze(-1)
+
+    output = [weighted_features[0]]
+    largest_j = -1
+    for bpe_indices in alignment:
+        output.append(weighted_features[bpe_indices].sum(dim=0))
+        largest_j = max(largest_j, *bpe_indices)
+    for j in range(largest_j + 1, len(features)):
+        output.append(weighted_features[j])
+    output = torch.stack(output)
+    assert torch.all(torch.abs(output.sum(dim=0) - features.sum(dim=0)) < 1e-4)
+    return output
+
+
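+# spaCy is imported lazily and the pipeline/tokenizer are cached on the function
+# objects below, so spaCy is only needed once word-aligned features are requested.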
+def spacy_nlp():
+    if getattr(spacy_nlp, '_nlp', None) is None:
+        try:
+            from spacy.lang.en import English
+            spacy_nlp._nlp = English()
+        except ImportError:
+            raise ImportError('Please install spacy with: pip install spacy')
+    return spacy_nlp._nlp
+
+
+def spacy_tokenizer():
+    if getattr(spacy_tokenizer, '_tokenizer', None) is None:
+        try:
+            nlp = spacy_nlp()
+            spacy_tokenizer._tokenizer = nlp.Defaults.create_tokenizer(nlp)
+        except ImportError:
+            raise ImportError('Please install spacy with: pip install spacy')
+    return spacy_tokenizer._tokenizer
diff --git a/fairseq/models/roberta/hub_interface.py b/fairseq/models/roberta/hub_interface.py
index 2cc519746c..f7ba231e3b 100644
--- a/fairseq/models/roberta/hub_interface.py
+++ b/fairseq/models/roberta/hub_interface.py
@@ -3,6 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from typing import List
+
import numpy as np
import torch
import torch.nn as nn
@@ -72,7 +74,7 @@ def decode(self, tokens: torch.LongTensor):
            return sentences[0]
        return sentences

-    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
+    def extract_features(self, tokens: torch.LongTensor, return_all_hiddens: bool = False) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.size(-1) > self.model.max_positions():
@@ -102,3 +104,32 @@ def predict(self, head: str, tokens: torch.LongTensor):
        features = self.extract_features(tokens)
        logits = self.model.classification_heads[head](features)
        return F.log_softmax(logits, dim=-1)
+
+    def extract_features_aligned_to_words(self, sentence: str, return_all_hiddens: bool = False) -> torch.Tensor:
+        """Extract RoBERTa features, aligned to spaCy's word-level tokenizer."""
+        from fairseq.models.roberta import alignment_utils
+        from spacy.tokens import Doc
+
+        nlp = alignment_utils.spacy_nlp()
+        tokenizer = alignment_utils.spacy_tokenizer()
+
+        # tokenize both with GPT-2 BPE and spaCy
+        bpe_toks = self.encode(sentence)
+        spacy_toks = tokenizer(sentence)
+        spacy_toks_ws = [t.text_with_ws for t in tokenizer(sentence)]
+        alignment = alignment_utils.align_bpe_to_words(self, bpe_toks, spacy_toks_ws)
+
+        # extract features and align them
+        features = self.extract_features(bpe_toks, return_all_hiddens=return_all_hiddens)
+        features = features.squeeze(0)
+        aligned_feats = alignment_utils.align_features_to_words(self, features, alignment)
+
+        # wrap in spaCy Doc
+        doc = Doc(
+            nlp.vocab,
+            words=['<s>'] + [x.text for x in spacy_toks] + ['</s>'],
+            spaces=[True] + [x.endswith(' ') for x in spacy_toks_ws[:-1]] + [True, False],
+        )
+        assert len(doc) == aligned_feats.size(0)
+        doc.user_token_hooks['vector'] = lambda token: aligned_feats[token.i]
+        return doc