Summary: Pull Request resolved: fairinternal/fairseq-py#795
Differential Revision: D16620488
Pulled By: myleott
fbshipit-source-id: 1998a9ccd8816fc7f590861fb4898f910a36bc1e
1 parent 5f34252 · commit abb7ed4
Showing 14 changed files with 530 additions and 556 deletions.
# Finetuning RoBERTa on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data:
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.
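For example, to preprocess only the `RTE` task, which should produce the `RTE-bin/` data directory used in step 3 below:
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data RTE
```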

### 3) Fine-tuning on GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=122      # 6 percent of the number of updates
LR=2e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 python train.py RTE-bin/ \
--restore-file $ROBERTA_PATH \
--max-positions 512 \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric
```

For each GLUE task, use the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5
`--max-sentences` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16
`--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598
`--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214

For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`. A sketch of the resulting command is shown below.
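As an illustration (a sketch we add here, not part of the original README), the `RTE` command above adapted for `STS-B` with the table's values would look roughly as follows; `STS-B-bin/` is our assumed output directory from the preprocessing script:
```bash
# Sketch: the RTE example adapted for STS-B (a regression task),
# using the per-task values from the table above.
TOTAL_NUM_UPDATES=3598  # from the table
WARMUP_UPDATES=214      # from the table
LR=2e-05
NUM_CLASSES=1           # regression target
MAX_SENTENCES=16
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 python train.py STS-B-bin/ \
--restore-file $ROBERTA_PATH \
--max-positions 512 \
--max-sentences $MAX_SENTENCES \
--max-tokens 4400 \
--task sentence_prediction \
--reset-optimizer --reset-dataloader --reset-meters \
--required-batch-size-multiple 1 \
--init-token 0 --separator-token 2 \
--arch roberta_large \
--criterion sentence_prediction \
--num-classes $NUM_CLASSES \
--dropout 0.1 --attention-dropout 0.1 \
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
--clip-norm 0.0 \
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
--max-epoch 10 \
--regression-target \
--best-checkpoint-metric loss   # note: no --maximize-best-checkpoint-metric
```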

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--max-sentences=16/32` depending on the task; the arithmetic is sketched below.
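As a rough illustration of that arithmetic (our sketch, not part of the original README), `--warmup-updates` is about 6 percent of `--total-num-update`:
```bash
# Illustrative arithmetic only, using the RTE values from the table.
MAX_EPOCH=10
TOTAL_NUM_UPDATES=2036                                  # RTE value from the table
UPDATES_PER_EPOCH=$(( TOTAL_NUM_UPDATES / MAX_EPOCH ))  # ~203 updates/epoch at bsz 16
WARMUP_UPDATES=$(( TOTAL_NUM_UPDATES * 6 / 100 ))       # 6% of total = 122, matching the table
echo "${UPDATES_PER_EPOCH} updates/epoch, ${WARMUP_UPDATES} warmup updates"
```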

b) The above command-line arguments and hyperparameters were tested on a single Nvidia `V100` GPU with `32GB` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--max-sentences`, as sketched below.
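For instance (our sketch, not from the original README), on a GPU with less memory you could keep the effective batch size at 16 like this:
```bash
# Sketch: halve --max-sentences and double --update-freq so the
# effective batch size stays at 8 * 2 = 16.
MAX_SENTENCES=8
UPDATE_FREQ=2
# ...then pass "--max-sentences $MAX_SENTENCES --update-freq $UPDATE_FREQ"
# in place of "--max-sentences 16" in the fine-tuning command above.
```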

c) All the settings in the table above are suggested settings from our hyperparameter search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparameter search.