Support different --max-positions and --tokens-per-sample
Summary: Pull Request resolved: facebookresearch#924

Differential Revision: D16548165

Pulled By: myleott

fbshipit-source-id: 49569ece3e54fad7b4f0dfb201ac99123bfdd4f2
Myle Ott authored and facebook-github-bot committed Jul 29, 2019
1 parent 2fe45f0 commit 33597e5
Showing 3 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions fairseq/models/roberta/hub_interface.py
@@ -43,6 +43,10 @@ def encode(self, sentence: str, *addl_sentences) -> torch.LongTensor:
     def extract_features(self, tokens: torch.LongTensor, return_all_hiddens=False) -> torch.Tensor:
         if tokens.dim() == 1:
             tokens = tokens.unsqueeze(0)
+        if tokens.size(-1) > self.model.max_positions():
+            raise ValueError('tokens exceeds maximum length: {} > {}'.format(
+                tokens.size(-1), self.model.max_positions()
+            ))
         features, extra = self.model(
             tokens.to(device=self.device),
             features_only=True,
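With this change, extract_features() rejects over-length input up front, giving a readable error instead of an index error inside the learned positional embeddings. A minimal usage sketch, assuming the standard torch.hub entry point for fairseq's pretrained RoBERTa:

    import torch

    # Load pretrained RoBERTa via torch.hub (requires network access).
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()

    tokens = roberta.encode('Hello world!')
    features = roberta.extract_features(tokens)  # short input: unaffected

    # Inputs longer than model.max_positions() now fail fast with ValueError.
    too_long = torch.zeros(roberta.model.max_positions() + 1, dtype=torch.long)
    try:
        roberta.extract_features(too_long)
    except ValueError as err:
        print(err)  # e.g. "tokens exceeds maximum length: 513 > 512"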
2 changes: 2 additions & 0 deletions fairseq/models/roberta/model.py
@@ -75,6 +75,8 @@ def add_args(parser):
                             help='dropout probability after activation in FFN')
         parser.add_argument('--pooler-dropout', type=float, metavar='D',
                             help='dropout probability in the masked_lm pooler layers')
+        parser.add_argument('--max-positions', type=int,
+                            help='number of positional embeddings to learn')

     @classmethod
     def build_model(cls, args, task):
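The new --max-positions flag decouples the model's positional-embedding budget from the task's --tokens-per-sample, per the commit title. A minimal argparse sketch of the intended interplay; the fallback to tokens_per_sample when the flag is unset is an assumption, since this diff only shows the argument definition:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--tokens-per-sample', type=int, default=512,
                        help='max tokens per sample for LM dataset')
    parser.add_argument('--max-positions', type=int,
                        help='number of positional embeddings to learn')

    # Train on 512-token blocks while learning extra positions, e.g. the
    # offset that learned positional embeddings reserve for padding.
    args = parser.parse_args(['--tokens-per-sample', '512',
                              '--max-positions', '514'])

    # Presumed fallback (assumption, not shown in this diff): without an
    # explicit --max-positions, the model keeps using tokens_per_sample.
    max_positions = (args.max_positions if args.max_positions is not None
                     else args.tokens_per_sample)
    print(max_positions)  # 514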
2 changes: 0 additions & 2 deletions fairseq/tasks/masked_lm.py
@@ -178,8 +178,6 @@ def is_beginning_of_word(i):
         )

     def build_dataset_for_inference(self, src_tokens, src_lengths, sort=True):
-        if self.args.also_lowercase_words:
-            raise NotImplementedError
         src_dataset = PadDataset(
             TokenBlockDataset(
                 src_tokens,
