Commit: test fixes
mattdangerw committed Feb 14, 2024
1 parent a2a412a commit f274d1e
Showing 6 changed files with 16 additions and 12 deletions.
8 changes: 5 additions & 3 deletions keras_nlp/layers/modeling/transformer_layer_utils.py
@@ -55,9 +55,11 @@ def compute_causal_mask(batch_size, input_length, output_length, cache_index=0):
     `(batch_size, output_length, input_length)` that can be passed to an
     attention layer.
     """
-    i = ops.expand_dims(ops.arange(output_length), axis=1) + cache_index
-    j = ops.arange(input_length)
-    mask = ops.expand_dims(ops.cast(i >= j, dtype="int32"), axis=0)
+    i = ops.arange(output_length, dtype="float32")
+    i = i + ops.cast(cache_index, "float32")
+    i = ops.expand_dims(i, axis=1)
+    j = ops.arange(input_length, dtype="float32")
+    mask = ops.expand_dims(i >= j, axis=0)
     return ops.broadcast_to(mask, (batch_size, output_length, input_length))


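For intuition, a minimal sketch of what the rewritten mask computation produces (assuming `compute_causal_mask` is imported from the module above; the toy sizes are hypothetical):

from keras_nlp.layers.modeling.transformer_layer_utils import compute_causal_mask

# One batch, a 5-token context, generating 2 positions starting at cache index 3.
mask = compute_causal_mask(
    batch_size=1, input_length=5, output_length=2, cache_index=3
)
# Row k may attend to keys j <= cache_index + k; the mask is now boolean
# rather than the previous int32:
# [[[ True,  True,  True,  True, False],
#   [ True,  True,  True,  True,  True]]]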
2 changes: 1 addition & 1 deletion keras_nlp/models/bart/bart_seq_2_seq_lm.py
@@ -257,7 +257,7 @@ def call_with_cache(
     ):
         tokens = self.backbone.token_embedding(token_ids)
         positions = self.backbone.decoder_position_embedding(
-            tokens, start_index=index,
+            tokens, start_index=index
         )
         # Sum, normalize and apply dropout to embeddings.
         x = self.backbone.decoder_embeddings_add((tokens, positions))
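The `start_index=index` argument is what keeps cached decoding positionally correct: each step feeds the model a single new token, so the position lookup must begin at the current step rather than at 0. A minimal sketch with keras_nlp's public `PositionEmbedding` layer (toy shapes are hypothetical):

import numpy as np
from keras_nlp.layers import PositionEmbedding

embedding = PositionEmbedding(sequence_length=8)
step_input = np.zeros((1, 1, 4), dtype="float32")  # (batch, 1, hidden_dim)
# Returns the embedding for position 3, not position 0.
pos_at_step_3 = embedding(step_input, start_index=3)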
4 changes: 1 addition & 3 deletions keras_nlp/models/bart/bart_seq_2_seq_lm_test.py
@@ -119,9 +119,7 @@ def wrapper(*args, **kwargs):
                 cache,
             )

-        with patch.object(
-            seq_2_seq_lm, "call_with_cache", wraps=wrapper
-        ):
+        with patch.object(seq_2_seq_lm, "call_with_cache", wraps=wrapper):
             inputs = {
                 "encoder_text": [
                     " airplane at airport",
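The test relies on `unittest.mock.patch.object` with `wraps=` to spy on `call_with_cache` while preserving its behavior. A self-contained illustration of the pattern (the `Greeter` class is hypothetical):

from unittest.mock import patch

class Greeter:
    def greet(self, name):
        return f"hello {name}"

greeter = Greeter()
calls = []

def wrapper(*args, **kwargs):
    # Record the arguments, then delegate to the real method.
    calls.append((args, kwargs))
    return Greeter.greet(greeter, *args, **kwargs)

# `wraps=wrapper` keeps real behavior while letting the test observe calls.
with patch.object(greeter, "greet", wraps=wrapper):
    assert greeter.greet("world") == "hello world"
assert calls == [(("world",), {})]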
@@ -141,6 +141,10 @@ def generate_postprocess(
             token_ids = ops.convert_to_numpy(token_ids)
         if not isinstance(padding_mask, tf.Tensor):
             padding_mask = ops.convert_to_numpy(padding_mask)
+        # Make sure the numpy array has type `int32` since
+        # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+        token_ids = tf.cast(token_ids, "int32")
+        padding_mask = tf.cast(padding_mask, "bool")
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
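The motivation for the new casts: samplers can hand back ids in other dtypes (e.g. int64), and per the commit's own comment `SentencePieceProcessor.detokenize` only accepts `int32` arrays. A sketch of the failure mode and the fix (the ids and the end-token id of 1 are hypothetical):

import numpy as np
import tensorflow as tf

token_ids = np.array([[2, 15, 37, 1]], dtype="int64")  # e.g. a sampler's output
token_ids = tf.cast(token_ids, "int32")  # required for SentencePiece detokenization
padding_mask = tf.cast(np.array([[1, 1, 1, 0]]), "bool")
# Mask out special tokens (the end marker, id 1 here) before detokenizing.
padding_mask = padding_mask & (token_ids != 1)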
8 changes: 4 additions & 4 deletions keras_nlp/models/mistral/mistral_causal_lm_preprocessor.py
@@ -155,12 +155,12 @@ def generate_postprocess(
         # Convert the inputs to numpy arrays if they aren't a tensor already.
         if not isinstance(token_ids, tf.Tensor):
             token_ids = ops.convert_to_numpy(token_ids)
-            # Make sure the numpy array has type `int32` since
-            # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
-            token_ids = token_ids.astype("int32")
         if not isinstance(padding_mask, tf.Tensor):
             padding_mask = ops.convert_to_numpy(padding_mask)
-            padding_mask = padding_mask.astype("bool")
+        # Make sure the numpy array has type `int32` since
+        # `SentencePieceProcessor.detokenize` only accepts `int32` arrays.
+        token_ids = tf.cast(token_ids, "int32")
+        padding_mask = tf.cast(padding_mask, "bool")
         # Strip any special tokens during detokenization (e.g. the start and
         # end markers). In the future we could make this configurable.
         padding_mask = padding_mask & (token_ids != self.tokenizer.end_token_id)
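Note what moved, not just what changed: in the old code the casts lived inside the `isinstance` branches, so inputs that already were `tf.Tensor`s skipped them entirely (and numpy's `.astype` cannot be applied to tensors anyway). Hoisting `tf.cast` out of both branches makes the dtype guarantee unconditional:

import tensorflow as tf

token_ids = tf.constant([[2, 15, 37]], dtype="int64")
# Old: `isinstance(token_ids, tf.Tensor)` is True, so no conversion ran and
# int64 ids reached the detokenizer. New: the cast always runs.
token_ids = tf.cast(token_ids, "int32")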
2 changes: 1 addition & 1 deletion keras_nlp/samplers/serialization.py
@@ -14,11 +14,11 @@

 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
-from keras_nlp.samplers.sampler import Sampler
 from keras_nlp.samplers.beam_sampler import BeamSampler
 from keras_nlp.samplers.contrastive_sampler import ContrastiveSampler
 from keras_nlp.samplers.greedy_sampler import GreedySampler
 from keras_nlp.samplers.random_sampler import RandomSampler
+from keras_nlp.samplers.sampler import Sampler
 from keras_nlp.samplers.top_k_sampler import TopKSampler
 from keras_nlp.samplers.top_p_sampler import TopPSampler

