Skip to content

Commit

Permalink
added byt5 option
Browse files Browse the repository at this point in the history
  • Loading branch information
u-brixton authored Aug 2, 2021
1 parent 5aa7921 commit 31ae9e0
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions simpletransformers/t5/t5_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm, trange
from transformers.models.t5 import T5Config, T5ForConditionalGeneration, T5Tokenizer
from transformers import ByT5Tokenizer, AutoTokenizer
from transformers.optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
Expand Down Expand Up @@ -53,6 +54,7 @@ def chunks(lst, n):
# Maps the user-facing model_type string to its (config class, model class) pair.
MODEL_CLASSES = {
    "t5": (T5Config, T5ForConditionalGeneration),
    "mt5": (MT5Config, MT5ForConditionalGeneration),
    # ByT5 reuses T5's config and model classes; only its tokenizer
    # (ByT5Tokenizer, resolved via AutoTokenizer) differs.
    "byt5": (T5Config, T5ForConditionalGeneration),
}


Expand Down Expand Up @@ -119,19 +121,21 @@ def __init__(
self.results = {}

config_class, model_class = MODEL_CLASSES[model_type]

if model_name is None:
self.config = self.args.config
self.model = model_class(config=self.config)
else:
self.config = config_class.from_pretrained(model_name, **self.args.config)
self.model = model_class.from_pretrained(model_name, config=self.config)

if isinstance(tokenizer, T5Tokenizer):
self.tokenizer = tokenizer
self.model.resize_token_embeddings(len(self.tokenizer))
else:
self.tokenizer = T5Tokenizer.from_pretrained(model_name, truncate=True)
if model_type != "byt5":
self.config = config_class.from_pretrained(model_name, **self.args.config)
self.model = model_class.from_pretrained(model_name, config=self.config)
else:
self.config = config_class.from_pretrained(model_name, **self.args.config)
self.model = model_class.from_pretrained(model_name)
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
if model_type != "byt5":
if isinstance(tokenizer, T5Tokenizer):
self.tokenizer = tokenizer
self.model.resize_token_embeddings(len(self.tokenizer))

if self.args.dynamic_quantize:
self.model = torch.quantization.quantize_dynamic(
Expand Down

0 comments on commit 31ae9e0

Please sign in to comment.