diff --git a/examples/roberta/README.md b/examples/roberta/README.md index 1bc55ed7b4..52550157f3 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -20,6 +20,7 @@ Model | Description | # params | Download ``` >>> import torch >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large') +>>> roberta.eval() # disable dropout (or leave in train mode to finetune) ``` ##### Apply Byte-Pair Encoding (BPE) to input text: @@ -31,9 +32,16 @@ tensor([ 0, 31414, 232, 328, 2]) ##### Extract features from RoBERTa: ``` ->>> features = roberta.extract_features(tokens) ->>> features.size() +>>> last_layer_features = roberta.extract_features(tokens) +>>> last_layer_features.size() torch.Size([1, 5, 1024]) + +>>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True) +>>> len(all_layers) +25 + +>>> torch.all(all_layers[-1] == last_layer_features) +tensor(1, dtype=torch.uint8) ``` ##### Use RoBERTa for sentence-pair classification tasks: diff --git a/fairseq/models/roberta.py b/fairseq/models/roberta.py index 42adfad303..d8001e6d7c 100644 --- a/fairseq/models/roberta.py +++ b/fairseq/models/roberta.py @@ -33,6 +33,7 @@ class RobertaHubInterface(nn.Module): Load RoBERTa:: >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large') + >>> roberta.eval() # disable dropout (or leave in train mode to finetune) Apply Byte-Pair Encoding (BPE) to input text:: @@ -42,10 +43,16 @@ class RobertaHubInterface(nn.Module): Extract features from RoBERTa:: - >>> features = roberta.extract_features(tokens) - >>> features.size() + >>> last_layer_features = roberta.extract_features(tokens) + >>> last_layer_features.size() torch.Size([1, 5, 1024]) + >>> all_layers = roberta.extract_features(tokens, return_all_hiddens=True) + >>> len(all_layers) + 25 + >>> torch.all(all_layers[-1] == last_layer_features) + tensor(1, dtype=torch.uint8) + Use RoBERTa for sentence-pair classification tasks:: >>> roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli') # already finetuned @@ -100,11 +107,20 @@ def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor: tokens = self.task.source_dictionary.encode_line(bpe_sentence, append_eos=True) return tokens.long() - def extract_features(self, tokens: torch.LongTensor) -> torch.Tensor: + def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor: if tokens.dim() == 1: tokens = tokens.unsqueeze(0) - features, _ = self.model(tokens.to(device=self.device), features_only=True) - return features + features, extra = self.model( + tokens.to(device=self.device), + features_only=True, + return_all_hiddens=return_all_hiddens, + ) + if return_all_hiddens: + # convert from T x B x C -> B x T x C + inner_states = extra['inner_states'] + return [inner_state.transpose(0, 1) for inner_state in inner_states] + else: + return features # just the last layer's features def register_classification_head( self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs