Small optimizations (#5)
* feat(test): add tqdm to get feedback when running locally

* fix(test): remove generation config warnings

* feat: compilation can be enabled only for decoding

This only enables compilation for the decoding step. Note that there is no big
speedup for now, probably because the slots' buffers grow over time, which
triggers recompilation. A minimal sketch of the gating follows this message.

* feat: logits post-processing happens on CPU

Logits post-processing is not heavyweight, and doing it on the CPU actually
speeds up decoding, because compilation is no longer re-triggered (see the
sketch after this message).

* fix: comparison to False should be `cond is False` (a short illustration follows below)
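
For illustration, a minimal sketch of how compilation can be gated to the decode path only. The `DBG_COMPILE` environment variable, the `openxla` backend and the `model.device.type == "xla"` check come from this commit; the helper name `maybe_compile_for_decoding` is just for the example.

import os

import torch


def maybe_compile_for_decoding(model):
    """Pick the model to use for single-token (decode) forward passes.

    Prefill keeps the eager model, so varying prompt shapes do not trigger
    fresh compilations; only the fixed-shape decode step gets compiled.
    """
    # `model.device` is available on transformers models, as in the commit.
    if model.device.type == "xla" and "DBG_COMPILE" in os.environ:
        return torch.compile(model, backend="openxla")
    return model

In the diff below, the generator keeps both handles: the plain model for prefill and the (possibly) compiled `model_one_token` for decode.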
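
Likewise, a sketch of the CPU-side token post-processing: copy the generated token ids to the host once and detokenize there, so the XLA graph is not extended by bookkeeping work. The standalone function and its signature are assumptions for the example; the commit applies the same idea inside `_decode_next_tokens`.

import torch


def decode_new_text(tokens_on_device: torch.Tensor, start: int, tokenizer) -> str:
    # One device-to-host copy; tokenizer.decode is cheap on CPU and, more
    # importantly, no longer adds ops to the XLA graph on every step.
    tokens = tokens_on_device.cpu()
    return tokenizer.decode(tokens[start:], skip_special_tokens=False)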
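
Finally, the `cond is False` change follows the usual lint rule for comparisons to singletons (flake8 E712); a generic illustration, not taken from the changed file:

finished = False
if finished is False:    # preferred: identity check against the singleton False
    print("still generating")
if finished == False:    # flagged by the linter: 0 and 0.0 also compare equal to False
    print("still generating")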
tengomucho authored Mar 22, 2024
1 parent fdcd7ea commit a8452e7
Showing 4 changed files with 45 additions and 21 deletions.
@@ -1,9 +1,10 @@
import copy
import logging
import time
import os
from abc import ABC
from enum import Enum
from typing import List, Optional, Tuple
from typing import List, Optional, Tuple, Dict

import torch
import torch_xla.core.xla_model as xm
@@ -181,7 +182,7 @@ def reset(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, s
selector: (`TokenSelector`):
An object implementing the updated token selection logic.
"""
self._tokens = input_ids.clone()
self._tokens = input_ids.cpu()
self._next_text_token_start = 0
self._next_text_token_end = torch.numel(self._tokens)
self._next_text = ""
@@ -210,25 +211,27 @@ def _decode_next_tokens(
self,
) -> str:
"""Hack to hopefully support generate_stream for the maximum number of tokenizers"""
# Copy the tokens to CPU to avoid recompilation on TPU. Post-processing is quite fast anyway.
tokens = self._tokens.cpu()
# We need to include the tokens that produced the last text to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
new_text = self._tokenizer.decode(self._tokens[self._next_text_token_start :], skip_special_tokens=False)
new_text = self._tokenizer.decode(tokens[self._next_text_token_start :], skip_special_tokens=False)
if new_text.endswith("�"):
# utf-8 char at the end means it's a potential unfinished byte sequence
# from byte fallback tokenization.
return ""

# Compare the generated text with the one using only the tokens producing the last one
last_text = self._tokenizer.decode(
self._tokens[self._next_text_token_start : self._next_text_token_end],
tokens[self._next_text_token_start : self._next_text_token_end],
skip_special_tokens=False,
)
if len(new_text) == len(last_text):
# Nothing new was actually generated
return ""
# Return the decoded text and store its token offsets
self._next_text_token_start = self._next_text_token_end
self._next_text_token_end = torch.numel(self._tokens)
self._next_text_token_end = torch.numel(tokens)
return new_text[len(last_text) :]

def append(self, next_token: int) -> str:
@@ -248,7 +251,7 @@ def append(self, next_token: int) -> str:
The corresponding decoded text (if any).
"""
self._tokens = torch.cat(
[self._tokens, torch.tensor([next_token], device=self._device, dtype=self._tokens.dtype)]
[self._tokens, torch.tensor([next_token], dtype=self._tokens.dtype)]
)
# Update mask only if it was set previously
if self._mask is not None:
@@ -304,6 +307,12 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
):
self.model = model
if model.device.type == "xla" and "DBG_COMPILE" in os.environ:
self.model_one_token = torch.compile(model, backend="openxla")
logger.debug("Model compiled for decoding")
else:
self.model_one_token = model

# Specify padding options for decoder-only architecture
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"
@@ -426,7 +435,7 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
# Reset/clear KV cache
self.past_key_values = None
generation, next_batch = self._generate_token(
batch.id, input_ids, attention_mask=attention_mask, position_ids=position_ids, **extra_args
batch.id, input_ids, self.model, attention_mask=attention_mask, position_ids=position_ids, **extra_args
)

# Reactivate previously active slots for the next decode, and append
@@ -494,15 +503,17 @@ def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
else:
extra_args["attention_mask"] = attention_mask
extra_args["past_key_values"] = self.past_key_values
return self._generate_token(next_batch_id, input_ids, position_ids=position_ids, **extra_args)
return self._generate_token(
next_batch_id, input_ids, self.model_one_token, position_ids=position_ids, **extra_args
)

def _generate_token(
self, next_batch_id: int, input_ids: torch.LongTensor, **forward_extra_params
self, next_batch_id: int, input_ids: torch.LongTensor, model: torch.nn.Module, **forward_extra_params
) -> Tuple[List[Generation], CachedBatch]:
# Add barrier to allow next graph step to always be the same
xm.mark_step()
# Forward
outputs = self.model(
outputs = model(
input_ids,
return_dict=True,
use_cache=True,
@@ -512,8 +523,11 @@ def _generate_token(
# Save KV cache
self.past_key_values = outputs.past_key_values
# Barrier for XLA model
xm.mark_step(wait=False)
xm.mark_step()
ret = self._post_generate(outputs, next_batch_id, input_ids)
return ret

def _post_generate(self, outputs: Dict, next_batch_id: int, input_ids: torch.LongTensor) -> Tuple[List[Generation], CachedBatch]:
generations = []
active_slots = False
for i, slot in enumerate(self.slots):
@@ -62,11 +62,7 @@ def from_pretrained(
model.config.batch_size = batch_size
if sequence_length is not None or getattr(model.config, "sequence_length", None) is None:
model.config.sequence_length = sequence_length

# Do eval, and compile
# Do eval
model.eval()
if device == "xla" and "DBG_COMPILE" in environ:
model = torch.compile(model, backend="openxla_eval")
logger.debug("Model compiled.")

return model
15 changes: 11 additions & 4 deletions text-generation-inference/tests/test_generator.py
@@ -1,5 +1,6 @@
import pytest
import os
from tqdm import tqdm
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
@@ -44,13 +45,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
# For these tests we can safely set typical_p to 1.0 (default)
typical_p = 1.0
if not do_sample:
# Drop top_p parameter to avoid warnings
top_p = 1.0
parameters = NextTokenChooserParameters(
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
seed=seed,
repetition_penalty=repetition_penalty,
typical_p=typical_p,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
@@ -121,7 +128,7 @@ def test_decode_single(input_text, max_new_tokens, generated_text, do_sample, model_path):
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
for i in range(max_new_tokens - 1):
for _ in tqdm(range(max_new_tokens - 1), "Decoding tokens"):
assert next_batch.size == 1
assert next_batch.max_tokens == 1024
assert len(generations) == 1
@@ -152,7 +159,7 @@ def test_decode_multiple(model_path):
assert len(tokens[0]) == 1
# Decode a few tokens
gen_tokens = 4
for _ in range(gen_tokens - 1):
for _ in tqdm(range(gen_tokens - 1), "Decoding tokens"):
generations, next_batch = generator.decode([next_batch])
assert len(generations) == 1
g = generations[0]
@@ -172,7 +179,7 @@ def test_decode_multiple(model_path):
assert len(tokens[1]) == 1
# Decode more tokens until we reach the maximum for the first request
batches = [next_batch, next_batch_1]
for _ in range(max_new_tokens - gen_tokens):
for _ in tqdm(range(max_new_tokens - gen_tokens), "Decoding tokens (2nd batch)"):
generations, next_batch = generator.decode(batches)
for g in generations:
tokens[g.request_id].append(g.tokens.ids[0])
@@ -189,7 +196,7 @@
assert output.generated_tokens == max_new_tokens
generated_text = output.text
# Continue decoding until the end of the second request
for _ in range(gen_tokens - 1):
for _ in tqdm(range(gen_tokens - 1), "Decoding tokens (finishing)"):
generations, next_batch = generator.decode([next_batch])
assert len(generations) == 1
g = generations[0]
9 changes: 8 additions & 1 deletion text-generation-inference/tests/test_generator_gemma.py
@@ -1,5 +1,6 @@
import pytest
import os
from tqdm import tqdm
from text_generation_server.generator import TpuGenerator
from text_generation_server.model import fetch_model
from text_generation_server.pb.generate_pb2 import (
@@ -35,13 +36,19 @@ def create_request(
seed: int = 0,
repetition_penalty: float = 1.0,
):
# For these tests we can safely set typical_p to 1.0 (default)
typical_p = 1.0
if not do_sample:
# Drop top_p parameter to avoid warnings
top_p = 1.0
parameters = NextTokenChooserParameters(
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=do_sample,
seed=seed,
repetition_penalty=repetition_penalty,
typical_p=typical_p,
)
stopping_parameters = StoppingCriteriaParameters(max_new_tokens=max_new_tokens)
return Request(id=id, inputs=inputs, parameters=parameters, stopping_parameters=stopping_parameters)
@@ -57,7 +64,7 @@ def test_decode_single(model_path):
batch = Batch(id=0, requests=[request], size=1, max_tokens=SEQUENCE_LENGTH)
generations, next_batch = generator.prefill(batch)
# We already generated one token: call decode max_new_tokens - 1 times
for _ in range(max_new_tokens - 1):
for _ in tqdm(range(max_new_tokens - 1)):
assert next_batch.size == 1
assert next_batch.max_tokens == 1024
assert len(generations) == 1
