Add return_all_hiddens flag to hub interface
Summary: Pull Request resolved: facebookresearch#909

Differential Revision: D16532919

Pulled By: myleott

fbshipit-source-id: 16ce884cf3d84579026e4406a75ba3c01a128dbd
Myle Ott authored and facebook-github-bot committed Jul 27, 2019
1 parent 17fcc72 commit 40f1687
Showing 2 changed files with 31 additions and 7 deletions.
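In short, `extract_features` on the hub interface gains a `return_all_hiddens` flag: by default it still returns only the last layer's features, and with the flag set it returns one tensor per layer. A rough usage sketch, based on the updated README below (exact sizes are for `roberta.large`):

```
import torch

# Load RoBERTa from the PyTorch Hub and disable dropout.
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()

# BPE-encode a sentence into token IDs.
tokens = roberta.encode('Hello world!')

# Default: only the last layer's features (B x T x C).
last_layer_features = roberta.extract_features(tokens)

# With the new flag: a list with one B x T x C tensor per layer
# (embeddings + 24 transformer layers = 25 entries for roberta.large).
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert torch.all(all_layers[-1] == last_layer_features)
```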
12 changes: 10 additions & 2 deletions examples/roberta/README.md
@@ -20,6 +20,7 @@ Model | Description | # params | Download
```
>>> import torch
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
@@ -31,9 +32,16 @@ tensor([ 0, 31414, 232, 328, 2])

##### Extract features from RoBERTa:
```
>>> features = roberta.extract_features(tokens)
>>> features.size()
>>> last_layer_features = roberta.extract_features(tokens)
>>> last_layer_features.size()
torch.Size([1, 5, 1024])
>>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
>>> len(all_layers)
25
>>> torch.all(all_layers[-1] == last_layer_features)
tensor(1, dtype=torch.uint8)
```

##### Use RoBERTa for sentence-pair classification tasks:
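The README hunk above only checks that the last entry of `all_layers` matches the default output. As an illustration of what per-layer features make possible (not part of this commit), the layers could be stacked and pooled; `layer_weights` below is a hypothetical parameter name, and `roberta`/`tokens` come from the example above:

```
import torch

# all_layers is a list of B x T x C tensors, one per layer.
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
stacked = torch.stack(all_layers, dim=0)   # L x B x T x C

# Simple average over layers.
mean_pooled = stacked.mean(dim=0)          # B x T x C

# Or an ELMo-style softmax-weighted mixture over layers.
layer_weights = torch.nn.Parameter(torch.zeros(len(all_layers)))
weights = torch.softmax(layer_weights, dim=0).view(-1, 1, 1, 1)
mixed = (weights * stacked).sum(dim=0)     # B x T x C
```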
26 changes: 21 additions & 5 deletions fairseq/models/roberta.py
@@ -33,6 +33,7 @@ class RobertaHubInterface(nn.Module):
Load RoBERTa::
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
>>> roberta.eval() # disable dropout (or leave in train mode to finetune)
Apply Byte-Pair Encoding (BPE) to input text::
@@ -42,10 +43,16 @@ class RobertaHubInterface(nn.Module):
Extract features from RoBERTa::
>>> features = roberta.extract_features(tokens)
>>> features.size()
>>> last_layer_features = roberta.extract_features(tokens)
>>> last_layer_features.size()
torch.Size([1, 5, 1024])
>>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
>>> len(all_layers)
25
>>> torch.all(all_layers[-1] == last_layer_features)
tensor(1, dtype=torch.uint8)
Use RoBERTa for sentence-pair classification tasks::
>>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli') # already finetuned
@@ -100,11 +107,20 @@ def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True)
return tokens.long()

def extract_features(self, tokens: torch.LongTensor) -> torch.Tensor:
def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
features, _ = self.model(tokens.to(device=self.device), features_only=True)
return features
features, extra = self.model(
tokens.to(device=self.device),
features_only=True,
return_all_hiddens=return_all_hiddens,
)
if return_all_hiddens:
# convert from T x B x C -> B x T x C
inner_states = extra['inner_states']
return [inner_state.transpose(0, 1) for inner_state in inner_states]
else:
return features # just the last layer's features

def register_classification_head(
self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
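A note on the new code path: the decoder's `inner_states` are time-major (`T x B x C`), while the hub interface returns batch-major tensors, hence the `transpose(0, 1)` applied per layer. A minimal self-contained sketch of that conversion:

```
import torch

# One layer's hidden states, time-major: T x B x C.
T, B, C = 5, 1, 1024
inner_state = torch.randn(T, B, C)

# The hub interface returns batch-major tensors: B x T x C.
batch_major = inner_state.transpose(0, 1)
assert batch_major.shape == (B, T, C)
```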
