
Create a user friendly inference demo #532

Open
@borisdayma

This is a feature request.

I like maxtext because it is very customizable and efficient for training.
The main issue I'm having is hacking together an inference function. The code is quite complex, so this is not straightforward.
The simple decode.py works, but it seems to be mainly an experimental development for streaming.

I think streaming will be really cool, but we would also benefit from an easy-to-use model.generate(input_ids, attention_mask, params) function:

  • it should allow prefill based on the length of input_ids (it is the user's responsibility to keep the number of distinct shapes small to avoid recompilation)
  • it should allow batched input, with left padding to support different input lengths
  • it should be compilable with jit/pjit
  • it should allow a few common sampling strategies: greedy, sampling (with temperature, top-k, top-p), and beam search
  • it should be usable without a separate engine/service, in case we want to make it part of a larger function that includes multiple models
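The input-side requirements above (left padding, a bounded set of prefill shapes, jit-compatible sampling) can be sketched in plain JAX. This is a minimal illustration, not MaxText code: left_pad, position_ids, sample_next and the buckets parameter are hypothetical names introduced here. Sampling options are marked static so each configuration compiles once; temperature is applied first, then top-k, then top-p.

```python
import functools
import jax
import jax.numpy as jnp

def left_pad(sequences, pad_id=0, buckets=None):
    """Left-pad token-id lists into (input_ids, attention_mask) arrays.

    Hypothetical helper: if `buckets` is given, pad up to the smallest bucket
    that fits, so a jitted prefill only ever sees a few distinct shapes.
    """
    length = max(len(s) for s in sequences)
    if buckets is not None:
        fitting = [b for b in sorted(buckets) if b >= length]
        if not fitting:
            raise ValueError(f"prompt of length {length} exceeds the largest bucket")
        length = fitting[0]
    input_ids = [[pad_id] * (length - len(s)) + list(s) for s in sequences]
    attention_mask = [[0] * (length - len(s)) + [1] * len(s) for s in sequences]
    return (jnp.array(input_ids, dtype=jnp.int32),
            jnp.array(attention_mask, dtype=jnp.int32))

def position_ids(attention_mask):
    # Real tokens get positions 0, 1, 2, ... however much padding precedes them;
    # padding positions are clamped to 0 (they are masked out anyway).
    return jnp.maximum(jnp.cumsum(attention_mask, axis=-1) - 1, 0)

@functools.partial(jax.jit, static_argnames=("greedy", "temperature", "top_k", "top_p"))
def sample_next(logits, rng, greedy=False, temperature=1.0, top_k=0, top_p=1.0):
    """Pick next tokens from logits of shape (batch, vocab)."""
    if greedy:
        return jnp.argmax(logits, axis=-1)
    logits = logits / temperature
    if top_k > 0:
        k = min(top_k, logits.shape[-1])
        kth_value = jax.lax.top_k(logits, k)[0][..., -1:]  # k-th largest per row
        logits = jnp.where(logits < kth_value, -jnp.inf, logits)
    if top_p < 1.0:
        sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]  # descending
        sorted_probs = jax.nn.softmax(sorted_logits, axis=-1)
        # Keep a token if the mass of higher-ranked tokens is below top_p
        # (this always keeps the most likely token).
        keep = (jnp.cumsum(sorted_probs, axis=-1) - sorted_probs) < top_p
        cutoff = jnp.min(jnp.where(keep, sorted_logits, jnp.inf), axis=-1, keepdims=True)
        logits = jnp.where(logits < cutoff, -jnp.inf, logits)
    return jax.random.categorical(rng, logits, axis=-1)
```

For example, left_pad([[5, 6, 7], [8]], buckets=(4, 8)) pads both prompts to length 4, so any batch whose longest prompt has at most 4 tokens reuses the same compiled prefill.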

This PR looked interesting: #402
I think it was mainly for benchmarking, though, as it didn't stop once every sequence in the batch had reached EOS, but it had nice prefill functionality.
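Stopping once the whole batch has hit EOS can still live inside a single jitted function by using jax.lax.while_loop, whose condition checks both the token budget and whether every row is finished. Below is a minimal greedy sketch under stated assumptions: toy_next_logits is a hypothetical bigram table standing in for a real MaxText decode step (which would also read and update a KV cache filled during prefill), and greedy_generate is an illustrative name, not an existing MaxText function.

```python
import functools
import jax
import jax.numpy as jnp

def toy_next_logits(params, last_tokens):
    # Hypothetical stand-in for a transformer decode step: a (vocab, vocab)
    # bigram table mapping the previous token to next-token logits.
    return params[last_tokens]

@functools.partial(jax.jit, static_argnames=("max_new_tokens", "eos_id", "pad_id"))
def greedy_generate(params, input_ids, max_new_tokens, eos_id, pad_id=0):
    batch = input_ids.shape[0]
    # With left padding, the last column holds each row's final real prompt token.
    last = input_ids[:, -1]
    out = jnp.full((batch, max_new_tokens), pad_id, dtype=input_ids.dtype)
    done = jnp.zeros((batch,), dtype=bool)

    def cond(state):
        step, _, _, done = state
        # Stop at the length limit or as soon as every row has emitted EOS.
        return (step < max_new_tokens) & ~jnp.all(done)

    def body(state):
        step, last, out, done = state
        nxt = jnp.argmax(toy_next_logits(params, last), axis=-1).astype(out.dtype)
        nxt = jnp.where(done, pad_id, nxt)  # finished rows keep emitting padding
        out = out.at[:, step].set(nxt)
        done = done | (nxt == eos_id)
        return step + 1, nxt, out, done

    init = (jnp.int32(0), last, out, done)
    steps, _, out, _ = jax.lax.while_loop(cond, body, init)
    return out, steps  # steps = number of decode iterations actually run
```

Because the batch is left-padded, the final column of input_ids is always each row's last real prompt token, which keeps the hand-off from a batched prefill to the decode loop simple. Beam search would need a different loop state and is not covered by this sketch.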
