Skip to content

Commit

Permalink
Merge pull request #158 from microsoft/no-special-tokens
Browse files Browse the repository at this point in the history
Do not add special tokens while tokenizing
  • Loading branch information
sordonia authored Feb 17, 2025
2 parents 00882d6 + 63ba02d commit 6d77866
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 46 deletions.
16 changes: 14 additions & 2 deletions mttl/datamodule/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def add_space_and_eos(self, sources, labels):

# adds the eos token
labels_ = [
l + ((" " + self.tokenizer.eos_token) if self.add_eos_to_targets else "")
l + ((self.tokenizer.eos_token) if self.add_eos_to_targets else "")
for l in labels_
]
return sources_, labels_
Expand All @@ -271,6 +271,7 @@ def prepare_inputs_for_seq2seq_family(self, sources, labels):
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
add_special_tokens=False,
)
tokenized_sources = self.tokenizer(
sources,
Expand All @@ -279,16 +280,21 @@ def prepare_inputs_for_seq2seq_family(self, sources, labels):
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
add_special_tokens=False,
)
else:
tokenized_labels = self.tokenizer(
labels, padding="longest", return_tensors=self.return_tensors
labels,
padding="longest",
return_tensors=self.return_tensors,
add_special_tokens=False,
)
tokenized_sources = self.tokenizer(
sources,
padding="longest",
return_tensors=self.return_tensors,
pad_to_multiple_of=self.pad_to_multiple_of,
add_special_tokens=False,
)
label_mask = tokenized_labels["attention_mask"].bool()
masked_labels = tokenized_labels["input_ids"].masked_fill(
Expand Down Expand Up @@ -326,13 +332,15 @@ def prepare_inputs_for_gpt_family(self, sources, labels):
return output_batch

if self.max_input_length > 0:
# make sure we truncate sources...labels if needed
if self.tokenizer.truncation_side == "left":
tokenized_labels = self.tokenizer(
labels,
max_length=self.max_input_length,
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
add_special_tokens=False,
)
else:
tokenized_sources = self.tokenizer(
Expand All @@ -341,6 +349,7 @@ def prepare_inputs_for_gpt_family(self, sources, labels):
padding=self.padding,
return_tensors=self.return_tensors,
truncation=True,
add_special_tokens=False,
)

tok_sources_plus_labels = self.tokenizer(
Expand All @@ -350,18 +359,21 @@ def prepare_inputs_for_gpt_family(self, sources, labels):
return_tensors=self.return_tensors,
truncation=True,
pad_to_multiple_of=self.pad_to_multiple_of,
add_special_tokens=False,
)
else:
tokenized_sources = self.tokenizer(
sources,
padding="longest",
return_tensors=self.return_tensors,
add_special_tokens=False,
)
tok_sources_plus_labels = self.tokenizer(
[i + t for i, t in zip(sources, labels)],
padding="longest",
return_tensors=self.return_tensors,
pad_to_multiple_of=self.pad_to_multiple_of,
add_special_tokens=False,
)

targets = tok_sources_plus_labels["input_ids"].clone()
Expand Down
4 changes: 2 additions & 2 deletions mttl/models/modifiers/hard_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def roll_along(arr, shifts, dim):
# move it to GPU
eps["input_ids"] = eps["input_ids"].to(input_ids.device)
eps["attention_mask"] = eps["attention_mask"].to(input_ids.device)
#

prompt_shifts = eps["attention_mask"].sum(1)
modify_labels = labels is not None and prompts[0].model_family == "gpt"

Expand All @@ -62,7 +62,7 @@ def roll_along(arr, shifts, dim):
if modify_labels:
labels = roll_along(labels, shifts, 1)

if padding_side == "right":
elif padding_side == "right":
# if padding side of tokenizer is right, then we move the padding to the left here
eps["input_ids"] = roll_along(eps["input_ids"], prompt_shifts, 1)
eps["attention_mask"] = roll_along(eps["attention_mask"], prompt_shifts, 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_adapter_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_expert_model_generate(tmp_path, create_dummy_expert, flan_data_module):
input_shift = batch["input_ids"].shape[1]

generation = module.generate(**batch, max_new_tokens=3)[:, input_shift:]
assert generation.cpu().numpy().tolist() == [[198, 198, 32]]
assert generation.cpu().numpy().tolist() == [[198, 198, 464]]

batch["attention_mask"][:1] = 0
generation = module.generate(**batch, max_new_tokens=3)[:, input_shift:]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_loglike_eval():
)
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-neo-125m")
result = evaluator.evaluate(model, num_batches=10)
assert np.allclose(result, 0.2, rtol=0.01)
assert np.allclose(result, 0.1, rtol=0.01)


def test_code_evaluator(mocker):
Expand Down
77 changes: 55 additions & 22 deletions tests/test_hard_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,11 @@ def setup_dataset(self):

@pytest.fixture
def dm_batch():
def _dm_batch(**kwargs):
def _dm_batch(
padding_side="right",
truncation_side="left",
for_generation=False,
):
dm = DummyDataModule(
DatasetConfig(
dataset="tiny_flan_id",
Expand All @@ -57,9 +61,10 @@ def _dm_batch(**kwargs):
max_input_length=1024,
train_batch_size=4,
predict_batch_size=2,
truncation_side="left",
padding_side=padding_side,
truncation_side=truncation_side,
),
**kwargs,
for_generation=for_generation,
)
dl = dm.val_dataloader()
batch = next(iter(dl))
Expand All @@ -80,38 +85,66 @@ def test_hard_prompt(padding_side, dm_batch):
)
text_1 = "This is a test prompt"
text_2 = "Test test"
padding_size = 5
prompt1 = HardPrompt(config, prompt_init=text_1)
prompt2 = HardPrompt(config, prompt_init=text_2)

flan_batch_for_generation = dm_batch(for_generation=True, val_mixin=False)
flan_batch_for_training = dm_batch()
flan_batch_for_training = dm_batch(padding_side=padding_side)

if padding_side == "left":
new_inputs = HardPrompt.parallel_forward(
[prompt1, prompt2], **flan_batch_for_generation
[prompt1, prompt2], **flan_batch_for_training
)
inputs_and_prompts, attn_masks, labels_and_prompts = new_inputs
assert tokenizer.batch_decode(inputs_and_prompts) == [
"This is a test prompt\nThis is a dev sentence",
"<|endoftext|>" * padding_size + "Test test\nThis is dev",
"<|endoftext|>This is a test prompt\nThis is a dev sentence a<|endoftext|>",
"<|endoftext|>" * 6 + "Test test\nThis is dev b<|endoftext|>",
]
assert torch.equal(
attn_masks,
torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
]
),
)
assert list(inputs_and_prompts.shape) == [2, 11]
assert list(inputs_and_prompts.shape) == [2, 14]
assert torch.equal(
labels_and_prompts,
torch.tensor(
[ # not sure if I understand the rolling of labels
[-100, -100, -100, -100, -100, -100, 220, 50256, 257],
[220, 50256, -100, -100, -100, -100, -100, -100, 275],
[
[
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
257,
50256,
],
[
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
-100,
275,
50256,
],
]
),
)
Expand All @@ -121,15 +154,15 @@ def test_hard_prompt(padding_side, dm_batch):
)
inputs_and_prompts, attn_masks, labels_and_prompts = new_inputs
assert tokenizer.batch_decode(inputs_and_prompts) == [
"This is a test prompt\nThis is a dev sentence a <|endoftext|>",
"Test test\nThis is dev b <|endoftext|>" + "<|endoftext|>" * padding_size,
"This is a test prompt\nThis is a dev sentence a<|endoftext|><|endoftext|>",
"Test test\nThis is dev b<|endoftext|>" + "<|endoftext|>" * 6,
]
assert torch.equal(
attn_masks,
torch.tensor(
[
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
]
),
)
Expand All @@ -150,8 +183,8 @@ def test_hard_prompt(padding_side, dm_batch):
-100,
-100,
257,
220,
50256,
-100,
],
[
-100,
Expand All @@ -161,13 +194,13 @@ def test_hard_prompt(padding_side, dm_batch):
-100,
-100,
275,
220,
50256,
-100,
-100,
-100,
-100,
-100,
-100,
],
]
),
Expand All @@ -185,7 +218,7 @@ def test_hard_prompt_eval(dm_batch):
"EleutherAI/gpt-neo-125m", model_family="gpt", for_generation=True
)

flan_batch_for_generation = dm_batch(for_generation=True, val_mixin=False)
flan_batch_for_generation = dm_batch(for_generation=True)
outputs = model.generate(
inputs=flan_batch_for_generation["input_ids"],
attention_mask=flan_batch_for_generation["attention_mask"],
Expand Down
2 changes: 1 addition & 1 deletion tests/test_packing.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_packing_and_attn(tiny_flan_id):
assert len(packed_ids) == sum([len(x) for x in input_ids])

# Check if the data is the one we expect (this can change if you change the model / tokenizer)
assert sum([sum(x) for x in input_ids]) == sum(packed_ids) == 3348702
assert sum([sum(x) for x in input_ids]) == sum(packed_ids) == 3348075

packed_batch = collator([packed_ds[0]])
input_batch = collator([ds[idx] for idx in range(first_seq_len)])
Expand Down
3 changes: 2 additions & 1 deletion tests/test_peer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from mttl.models.containers.peer_container import PEERMLPContainer
import numpy as np
import pytest

Expand All @@ -17,7 +18,7 @@ def test_peer_moe(tmp_peer_moe_config, dummy_batch):
)

output = module(**dummy_batch).loss
assert np.allclose(output.item(), 18.0, atol=0.1)
assert isinstance(module.experts_containers[0], PEERMLPContainer)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 6d77866

Please sign in to comment.