Merge pull request #984 from Lyrcaxis/sample-tests
Added sampling tests
martindevans authored Nov 14, 2024
2 parents f340f31 + 0d1af94 commit 619bb5a
Showing 3 changed files with 204 additions and 12 deletions.
186 changes: 186 additions & 0 deletions LLama.Unittest/SamplingTests.cs
@@ -0,0 +1,186 @@
using LLama.Common;
using LLama.Native;

using System.Numerics.Tensors;
using System.Runtime.InteropServices;
using System.Text;

using Xunit.Abstractions;

namespace LLama.Unittest
{
public class SamplingTests : IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaWeights _model;
private readonly ModelParams _params;

private readonly LLamaBatch _batch;
private readonly StreamingTokenDecoder _decoder;

public void Dispose() => _model.Dispose();

public SamplingTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams(Constants.GenerativeModelPath) {
ContextSize = 200,
BatchSize = 200,
GpuLayerCount = Constants.CIGpuLayerCount,
};
_model = LLamaWeights.LoadFromFile(_params);
_batch = new LLamaBatch();
_decoder = new(Encoding.UTF8, _model);
}


[Fact]
public void Sampling()
{
using var context = new LLamaContext(_model, _params);
var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();

// Add "I will repeat this phrase forever.\nI will", without requesting any logits.
for (int i = 0; i < tokens.Length; i++) { _batch.Add(token: tokens[i], pos: i, sequence: LLamaSeqId.Zero, logits: false); }
for (int i = 0; i < 2; i++) { _batch.Add(token: tokens[i], pos: tokens.Length + i, sequence: LLamaSeqId.Zero, logits: false); }

// Add " repeat" and test whether next tokens will be "this phrase forever.".
for (int i = 0; i < 4; i++)
{
_batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: LLamaSeqId.Zero, logits: true);
DecodeAndClear(context);

var expected = tokens[i + 3];
var logits = context.NativeHandle.GetLogits(numTokens: 1);

// Test raw sampling
Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));

// Test native sampling with `LLamaTokenDataArrayNative`.
var array = LLamaTokenDataArray.Create(logits);
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
var rawLogits = new float[_model.VocabCount];
for (int j = 0; j < cur_p.Data.Length; j++)
{
rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
}
Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
}

// Test sampling chain
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
using var chain = CreateChain(context.NativeHandle);
chain.Apply(ref cur_p);
Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
}

// Test logit bias
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
using var chain = CreateChain(context.NativeHandle, logitBias);
chain.Apply(ref cur_p);
Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
}
}
}


[Fact]
public void BatchedSampling()
{
const int batch_count = 4;
using var context = new LLamaContext(_model, _params);
var tokens = _model.NativeHandle.Tokenize("I will repeat this phrase forever.\n", false, false, Encoding.UTF8);
var logitBias = tokens.Select(x => new LLamaLogitBias() { Token = x, Bias = -1000 }).ToArray();

// Add "I will repeat this phrase forever.\nI will", without requesting any logits.
for (int i = 0; i < tokens.Length + 2; i++)
{
for (int b = 0; b < batch_count; b++)
{
_batch.Add(token: tokens[i % tokens.Length], pos: i, sequence: (LLamaSeqId) b, logits: false);
}
}

// Add " repeat" and test whether next tokens will be "this phrase forever.".
for (int i = 0; i < 4; i++)
{
for (int b = 0; b < batch_count; b++)
{
_batch.Add(token: tokens[i + 2], pos: tokens.Length + i + 2, sequence: (LLamaSeqId) b, logits: true);
}
DecodeAndClear(context);

var expected = tokens[i + 3];
var all_logits = context.NativeHandle.GetLogits(numTokens: batch_count);

for (int b = 0; b < batch_count; b++)
{
var logits = all_logits.Slice(b * _model.VocabCount, _model.VocabCount);

// Test raw sampling
Assert.Equal(expected, TensorPrimitives.IndexOfMax(logits));

// Test native sampling with `LLamaTokenDataArrayNative`.
var array = LLamaTokenDataArray.Create(logits);
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
var rawLogits = new float[_model.VocabCount];
for (int j = 0; j < cur_p.Data.Length; j++)
{
rawLogits[(int) cur_p.Data[j].ID] = cur_p.Data[j].Logit;
}
Assert.Equal(expected, TensorPrimitives.IndexOfMax(rawLogits));
}

// Test sampling chain
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
using var chain = CreateChain(context.NativeHandle);
chain.Apply(ref cur_p);
Assert.Equal(expected, cur_p.Data[(int) cur_p.Selected].ID);
}

// Test logit bias
{
using var _ = LLamaTokenDataArrayNative.Create(array, out var cur_p);
using var chain = CreateChain(context.NativeHandle, logitBias);
chain.Apply(ref cur_p);
Assert.NotEqual(expected, cur_p.Data[(int) cur_p.Selected].ID);
}
}
}
}


private void DecodeAndClear(LLamaContext context)
{
context.Decode(_batch);
_batch.Clear();
}

private static SafeLLamaSamplerChainHandle CreateChain(SafeLLamaContextHandle context, LLamaLogitBias[]? logit_bias = null)
{
var chain = SafeLLamaSamplerChainHandle.Create(LLamaSamplerChainParams.Default());

// Penalties are configured as an identity transform (repeat = 1, freq = 0, presence = 0),
// so the stage is exercised without altering the logits.
chain.AddPenalties(
vocabSize: context.VocabCount,
eos: context.ModelHandle.Tokens.EOS,
newline: context.ModelHandle.Tokens.Newline ?? 0,
penaltyCount: 60, repeat: 1, freq: 0, presence: 0,
penalizeNewline: false, ignoreEOS: false
);

if (logit_bias != null) { chain.AddLogitBias(context.VocabCount, logit_bias); }

// Near-greedy sampling: top-k 10 with a low temperature, seeded for reproducibility.
chain.AddTopK(10);
chain.AddTemperature(0.1f);
chain.AddDistributionSampler(seed: 42);

return chain;
}
}
}
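
Taken together, the pieces exercised above compose into a single decode-and-sample step. A minimal sketch, not part of the diff (`context`, `_batch`, and `CreateChain` are assumed from the test class above):

// Sketch: decode a batch whose final token requested logits, then sample the next token.
context.Decode(_batch);
var logits = context.NativeHandle.GetLogits(); // one row: [1 x n_vocab]
var array = LLamaTokenDataArray.Create(logits);
using var native = LLamaTokenDataArrayNative.Create(array, out var cur_p);
using var chain = CreateChain(context.NativeHandle);
chain.Apply(ref cur_p); // runs penalties, top-k, temperature, then the seeded draw
var next = cur_p.Data[(int) cur_p.Selected].ID; // the sampled token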
16 changes: 8 additions & 8 deletions LLama.Unittest/Transformers/PromptTemplateTransformerTests.cs
@@ -60,14 +60,14 @@ public void ToModelPrompt_FormatsCorrectly()

// Call once with empty array to discover length
var templateResult = PromptTemplateTransformer.ToModelPrompt(templater);
const string expected = "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nworld<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n111<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\naaa<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n222<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nbbb<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n333<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nccc<|eot_id|>"
const string expected = "<|start_header_id|>assistant<|end_header_id|>\n\nhello<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nworld<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n111<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\naaa<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n222<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nbbb<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n333<|eot_id|>"
+ "<|start_header_id|>user<|end_header_id|>\n\nccc<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n";

Assert.Equal(expected, templateResult);
14 changes: 10 additions & 4 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -431,21 +431,27 @@ public void ClearLoraAdapters()

#region GetLogits
/// <summary>
/// Token logits obtained from the last call to llama_decode
/// The logits for the last token are stored in the last row
/// Token logits obtained from the last call to llama_decode.
/// The logits for the last token are stored in the last row.
/// Only tokens for which logits were requested (`logits = true`) are present.<br/>
/// Can be mutated in order to change the probabilities of the next token.<br />
/// Rows: n_tokens<br />
/// Cols: n_vocab
/// </summary>
/// <param name="numTokens">
/// The number of tokens whose logits should be retrieved, laid out in <b>[numTokens X n_vocab]</b> format.<br/>
/// Tokens are ordered as they were added to the LLamaBatch (first added, first returned).<br/>
/// This is helpful when requesting logits for many tokens in a sequence, or when decoding multiple sequences in one go.
/// </param>
/// <returns>A span over the requested logits, of length <c>numTokens * n_vocab</c>.</returns>
public Span<float> GetLogits()
public Span<float> GetLogits(int numTokens = 1)
{
var model = ThrowIfDisposed();

unsafe
{
var logits = llama_get_logits(this);
return new Span<float>(logits, model.VocabCount);
return new Span<float>(logits, model.VocabCount * numTokens);
}
}

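For reference, a sketch of how the new `numTokens` parameter supports batched sampling, based on the `BatchedSampling` test above (`model`, `context`, `batchCount`, and `CreateChain` are assumed, not part of this diff):

// Sketch: each of `batchCount` sequences requested logits for its last token, so
// GetLogits returns batchCount rows of n_vocab floats, in the order tokens were batched.
var allLogits = context.NativeHandle.GetLogits(numTokens: batchCount);
for (int b = 0; b < batchCount; b++)
{
    var logits = allLogits.Slice(b * model.VocabCount, model.VocabCount);
    var array = LLamaTokenDataArray.Create(logits);
    using var native = LLamaTokenDataArrayNative.Create(array, out var cur_p);
    using var chain = CreateChain(context.NativeHandle);
    chain.Apply(ref cur_p);
    var next = cur_p.Data[(int) cur_p.Selected].ID; // sampled token for sequence b
}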
