Skip to content

Commit d812344

Browse files
authored
Add decode stream (#64)
1 parent 0d137da commit d812344

File tree

10 files changed

+453
-120
lines changed

10 files changed

+453
-120
lines changed

lib/tokenizers/decode_stream.ex

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
defmodule Tokenizers.DecodeStream do
  @moduledoc """
  Implements streaming decoding functionality for tokenizers.

  A decode stream is created with `new/1` and then advanced one token ID at a
  time with `step/3`. The stream itself is a struct wrapping a reference to a
  resource owned by `Tokenizers.Native`.
  """

  @enforce_keys [:resource]
  defstruct [:resource]

  @type t :: %__MODULE__{
          resource: reference()
        }

  @doc """
  Creates a new decode stream.

  ## Options

    * `:skip_special_tokens` - determines whether special tokens should be
      skipped during decoding. By default, it is set to `false`.

  Raises `ArgumentError` if `opts` contains any other key.
  """
  @spec new(keyword()) :: t()
  def new(opts \\ []) when is_list(opts) do
    opts = Keyword.validate!(opts, skip_special_tokens: false)

    # The key is guaranteed to be present after validation, so fetch it
    # assertively rather than through the nil-tolerant Access syntax.
    opts
    |> Keyword.fetch!(:skip_special_tokens)
    |> Tokenizers.Native.decoder_stream_new()
  end

  @doc """
  Steps through the decode stream with the given tokenizer and token ID.

  Returns `{:ok, String.t()}` if there's a decoded string, or
  `{:ok, :out_of_range}` if the token ID is out of range.

  Returns `{:error, reason}` if an error occurs during decoding.
  """
  @spec step(t(), tokenizer :: term(), id :: integer()) ::
          {:ok, String.t()} | {:ok, :out_of_range} | {:error, term()}
  def step(%__MODULE__{} = decode_stream, tokenizer, id) when is_integer(id) do
    case Tokenizers.Native.decoder_stream_step(decode_stream, tokenizer, id) do
      {:ok, decoded} when is_binary(decoded) ->
        {:ok, decoded}

      # The native layer reports "no string for this ID" as `nil`; expose it
      # as a tagged atom so callers can match on it explicitly.
      {:ok, nil} ->
        {:ok, :out_of_range}

      {:error, reason} ->
        {:error, reason}
    end
  end

  @doc """
  Returns information about the decode stream state.
  """
  @spec info(t()) :: term()
  defdelegate info(decode_stream), to: Tokenizers.Native, as: :decoder_stream_info

  defimpl Inspect do
    import Inspect.Algebra
    alias Tokenizers.DecodeStream

    # Renders the stream as `#Tokenizers.DecodeStream<info>`.
    #
    # `to_doc/2` returns an algebra document, which is not necessarily a
    # binary (maps, lists and keywords produce tuples). Interpolating it into
    # a string would raise for those values, so the pieces are combined with
    # `concat/1`, which accepts documents and binaries alike.
    def inspect(decode_stream, opts) do
      concat([
        "#Tokenizers.DecodeStream<",
        to_doc(DecodeStream.info(decode_stream), opts),
        ">"
      ])
    end
  end
end

lib/tokenizers/native.ex

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ defmodule Tokenizers.Native do
3333
def decoders_ctc(_options), do: err()
3434
def decoders_sequence(_decoders), do: err()
3535

36+
# DecoderStream
37+
# Placeholder for the decode-stream step operation; the body only calls err/0
# (defined outside this view). Tokenizers.DecodeStream.step/3 calls this with
# the stream struct, a tokenizer, and an integer token ID, and matches on
# {:ok, binary}, {:ok, nil} and {:error, reason} results.
# NOTE(review): presumably replaced by the native implementation at load time — confirm.
def decoder_stream_step(_decoder_stream, _tokenizer, _id), do: err()
38+
#
39+
# Placeholder for the decode-stream info operation; the body only calls err/0
# (defined outside this view). Tokenizers.DecodeStream.info/1 delegates here
# and the result is rendered by that module's Inspect implementation.
# NOTE(review): presumably replaced by the native implementation at load time — confirm.
def decoder_stream_info(_decoder_stream), do: err()
40+
#
41+
# Placeholder for decode-stream construction; the body only calls err/0
# (defined outside this view). Tokenizers.DecodeStream.new/1 calls this with
# the value of its :skip_special_tokens option.
# NOTE(review): presumably replaced by the native implementation at load time — confirm.
def decoder_stream_new(_skip_special_tokens), do: err()
42+
3643
# Encoding
3744
def encoding_get_length(_encoding), do: err()
3845
def encoding_get_n_sequences(_encoding), do: err()

0 commit comments

Comments
 (0)