
2025/03/10/sampling #4

utterances-bot opened this issue Mar 11, 2025 · 15 comments

utterances-bot commented Mar 11, 2025

Sorting-Free GPU Kernels for LLM Sampling | FlashInfer

Background

https://flashinfer.ai/2025/03/10/sampling.html

platypus1989 commented

Thanks a lot for the great introduction, this is super helpful! Quick question about the curve of sampling latency scaling with batch size: why is there a bump in latency from around batch size 130 to 140 for PyTorch?

yzh119 (Collaborator) commented Apr 4, 2025

Hi @platypus1989, it might be because of a change in kernel choice (e.g. a grid size bump from 128 to 256) at different batch sizes. Maybe @xslingcn can provide the raw trace file and check the kernel configuration under different batch sizes.
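For anyone who wants to dig in, here is a minimal sketch of capturing such a trace with the PyTorch profiler (shapes are illustrative, and torch.multinomial is only a stand-in for the measured PyTorch sampling path):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Illustrative batch and vocabulary sizes around the region where the bump appears.
probs = torch.softmax(torch.randn(140, 128256, device="cuda"), dim=-1)

with profile(activities=[ProfilerActivity.CUDA]) as prof:
    samples = torch.multinomial(probs, num_samples=1)  # stand-in for the benchmarked sampling call

# The exported Chrome trace records each kernel's name and grid/block dimensions,
# so the launch configuration at batch 130 vs. 140 can be compared directly.
prof.export_chrome_trace("sampling_trace.json")
```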

NonvolatileMemory commented

Can we use Gumbel sampling so that only an argmax is needed for sampling?
e.g., https://github.com/NonvolatileMemory/fast_llm_sampling/tree/main

yzh119 (Collaborator) commented Apr 15, 2025

Sure, using Gumbel sampling is a great idea; the only issue is that we have to generate far more random numbers (one per vocabulary entry) than inverse transform sampling (just one).
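A minimal PyTorch sketch of the two approaches, just to illustrate the difference in random-number counts (this is not the FlashInfer kernel):

```python
import torch

logits = torch.randn(128256, device="cuda")  # illustrative vocabulary size

# Gumbel-max: perturb every logit with independent Gumbel(0, 1) noise, then argmax.
# This draws one random number per vocabulary entry.
gumbel = -torch.log(-torch.log(torch.rand_like(logits)))
token_gumbel = torch.argmax(logits + gumbel)

# Inverse transform: draw a single uniform random number and search the CDF.
probs = torch.softmax(logits, dim=-1)
u = torch.rand(1, device="cuda")
token_inverse = torch.searchsorted(torch.cumsum(probs, dim=-1), u)
```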

NonvolatileMemory commented

Maybe we can pre-generate and cache them.

yzh119 (Collaborator) commented Apr 15, 2025

Maybe we can pre-generate and cache them.

Then you have to load these Gumbel random numbers, doubling the I/O of the kernel (sampling without filtering is I/O-bound).
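As a rough illustration (numbers made up): with batch size 128 and a 128K vocabulary in fp32, the probability tensor alone is about $128 \times 131072 \times 4\,\text{B} \approx 64$ MB of reads per sampling step, and streaming a cached Gumbel-noise tensor of the same shape adds roughly another 64 MB.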

yzh119 (Collaborator) commented Apr 30, 2025

Hi @NonvolatileMemory, flashinfer-ai/flashinfer#1035 provides a reference implementation of Gumbel sampling that generates the Gumbel random numbers on the fly. It is slower than inverse transform sampling from probabilities, but faster than softmax + sampling from probabilities.


rain7996 commented May 6, 2025

Hi @yzh119, can Dual Pivot Rejection Sampling generate the same probability distribution as the classical sorting-based sampling methods, which simply select the minimal token set in top-p sampling?

yzh119 (Collaborator) commented May 6, 2025

Hi @rain7996, yes, Dual Pivot Rejection Sampling helps quickly identify the top-p/top-k threshold and shouldn't change the probability distribution. You can check our unit tests: https://github.com/flashinfer-ai/flashinfer/blob/main/tests/test_sampling.py#L81-L103
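A much simplified version of that style of check (illustrative only, not the actual test file) compares empirical sampling frequencies against the renormalized top-p distribution:

```python
import torch

torch.manual_seed(0)
vocab_size, top_p, num_trials = 1000, 0.9, 200_000

probs = torch.softmax(torch.randn(vocab_size), dim=-1)

# Reference: renormalize over the smallest prefix of sorted probs whose mass reaches top_p.
sorted_probs, sorted_idx = torch.sort(probs, descending=True)
cutoff = int((torch.cumsum(sorted_probs, dim=0) < top_p).sum()) + 1
ref = torch.zeros_like(probs)
ref[sorted_idx[:cutoff]] = sorted_probs[:cutoff] / sorted_probs[:cutoff].sum()

# Replace this line with the sampler under test; torch.multinomial over the
# reference distribution is just a placeholder here.
samples = torch.multinomial(ref, num_trials, replacement=True)

freq = torch.bincount(samples, minlength=vocab_size).float() / num_trials
assert torch.allclose(freq, ref, atol=1e-2)
```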


rain7996 commented May 6, 2025

Is there any theoretical analysis? Why should we choose the token "and" in the animation instead of any of the other remaining tokens which, together with the selected token "and", make up the smallest set with probability larger than top_p?

yzh119 (Collaborator) commented May 6, 2025

Is there any theoretical analysis?

A proof is attached for your reference; please check the Theoretical Proof of the Correctness of Rejection Sampler section.

Why shall we choose the token "and" in the animation instead of any other remaining tokens

It totally depends on the value of the uniform random $u$ (a random number sampled independently from curand for each round); the animation shows the behavior when $u=0.33$ at round 2, but it could be any value between 0 and 1.
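For concreteness, the per-round inverse-transform step behaves roughly like this (made-up probabilities, not the animation's exact values or the CUDA kernel):

```python
import torch

# Surviving tokens in the current round with made-up renormalized probabilities.
tokens = ["and", "the", "of", "to"]
probs = torch.tensor([0.40, 0.30, 0.20, 0.10])

u = 0.33                               # the uniform random number drawn for this round
cdf = torch.cumsum(probs, dim=0)       # [0.40, 0.70, 0.90, 1.00]
idx = int(torch.searchsorted(cdf, u))  # first entry whose cumulative mass reaches u -> 0
print(tokens[idx])                     # "and"; a different u would pick a different token
```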

Stonesjtu commented

Great work on the new rejection sampling algorithm and proof!
I suppose you can use an arbitrary number of pivots to speculatively eliminate sampling steps, which leads to less scanning of the probability tensor.
What's the performance gain of dual-pivot vs. single-pivot sampling in practical cases?


rain7996 commented May 8, 2025

@yzh119 Great, thanks for your explanation!

xslingcn (Contributor) commented May 8, 2025

Hi @Stonesjtu, sure, that's a good point! We ran some experiments with the top-p/top-k renorm kernels (flashinfer-ai/flashinfer#974) and showed that increasing the number of pivots doesn't improve performance indefinitely. The trend should generally align with the sampling kernels. A detailed performance comparison can be found in the spreadsheet.

Another major reason we introduced the dual pivot is that we want to make sure the algorithm stops within a determinable maximum number of rounds. Although the single-pivot method may select an acceptable probability faster under certain configurations, it will struggle with some extremely skewed distributions, which are common for LLM outputs.


In terms of worst-case steps: a strong boost going from N to log2(N), but only a minor improvement going from log2(N) to log3(N), which may not compensate for the complexity introduced by the extra pivot.
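To put rough numbers on that (illustrative vocabulary size): for a vocabulary of $2^{17} = 131072$ tokens, the worst case drops from on the order of $131072$ rounds with a single pivot on an adversarial distribution to $\log_2 131072 = 17$ rounds with the dual pivot, whereas a third pivot would only bring it down further to about $\log_3 131072 \approx 10.7$ rounds.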
