Custom mask slows down attention. #724
Yes, the custom mask has significant overhead: the memory-access pattern to the custom mask is not coalesced. For long sequences, it's encouraged to use the sparse API instead.
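To make the suggestion concrete, here is a small sketch (not the flashinfer API itself) of turning a dense boolean mask into the block-sparse BSR-style `indptr`/`indices` layout that block-sparse attention kernels typically consume; the block size `R x C` and the function name are assumptions for illustration.

```python
import numpy as np

def dense_mask_to_bsr(mask: np.ndarray, R: int, C: int):
    """Return (indptr, indices) listing the nonzero R x C blocks per block-row."""
    MB, NB = mask.shape[0] // R, mask.shape[1] // C
    indptr = [0]
    indices = []
    for i in range(MB):
        for j in range(NB):
            # keep a block if it contains at least one unmasked entry
            if mask[i * R:(i + 1) * R, j * C:(j + 1) * C].any():
                indices.append(j)
        indptr.append(len(indices))
    return np.asarray(indptr), np.asarray(indices)

# Example: a causal mask over an 8x8 grid with 2x2 blocks keeps only
# the lower-triangular blocks.
mask = np.tril(np.ones((8, 8), dtype=bool))
indptr, indices = dense_mask_to_bsr(mask, R=2, C=2)
print(indptr.tolist())   # [0, 1, 3, 6, 10]
print(indices.tolist())  # [0, 0, 1, 0, 1, 2, 0, 1, 2, 3]
```

The kernel then only visits the listed blocks, which is where the speedup over a dense custom mask comes from.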
Is there a better way to support tree attention (speculative decoding) than with a custom mask?
@ZhongYingMatrix as I mentioned, using the sparse attention API would be much faster, especially for long context. See this unit test for their equivalence:
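For reference, the tree-attention mask in speculative decoding has a simple structure: each draft token attends to itself and all of its ancestors in the draft tree. A minimal NumPy sketch (illustrative only, with `parents[i]` giving token `i`'s parent, or `-1` for the root):

```python
import numpy as np

def tree_attention_mask(parents):
    """Dense tree-attention mask: token i attends to itself and its ancestors."""
    n = len(parents)
    mask = np.zeros((n, n), dtype=bool)
    for i in range(n):
        node = i
        while node != -1:          # walk up to the root
            mask[i, node] = True
            node = parents[node]
    return mask

# A small draft tree: root 0 with branches 0 -> 1 -> 3 and 0 -> 2 -> 4.
mask = tree_attention_mask([-1, 0, 0, 1, 2])
print(mask.astype(int))
# [[1 0 0 0 0]
#  [1 1 0 0 0]
#  [1 0 1 0 0]
#  [1 1 0 1 0]
#  [1 0 1 0 1]]
```

Because each row is an ancestor chain, this mask is highly sparse for wide trees, which is why the sparse API suits it better than a dense custom mask.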
Hey @yzh119, thanks for sharing the context! I have a similar issue in a different setting. Suppose my mask is a dense block-wise mask with the structure shown below, which is fairly dense for the most part. Which API would you suggest? I tried both, but since my case has a quite dense block on the left side, both APIs seem to have high latency overhead. The sparse one can be slightly faster when the mask is made sparser, but in general the block-wise sparsity seems hard to exploit here for speed. Another option is calling the API multiple times, but that may introduce high latency as well. (I'm integrating this into sglang currently.) Thank you!

A more general question: at what sparsity ratio or pattern would you suggest using the sparse API rather than the dense one? Also, I'm curious about the timeline for supporting fp8 for FA3 in 0.2.1, and whether sliding window and custom mask are both supported jointly in FA3. Thank you in advance!
Hey @yzh119, I dug in a bit, and there are some minor improvements we can make for a small speedup.
But this is pretty minor. I'm working more on the kernel side to see if we can find further improvements and catch up a bit more with the causal kernel.
It feels like your mask can be implemented on top of causal. I think you can pass in two additional parameters to indicate which rows and columns are masked, without creating a complete custom mask.
Hey @qiyuxinlin, do you mean two vectors of delimiters to decide the masking? As you can see from the figure, each q has its own attended kv plus a jointly shared kv to attend to, so I guess it's hard to express with causal? (I could provide one int to indicate the shared part for all the q's, if that's what you're suggesting.) Thank you in advance!
From your figure, I observe that the overall structure is a lower-triangular causal mask, just with certain rows, columns, and some blocks additionally masked out. Would passing in only that information satisfy your needs? I'm only guessing, though.
Thanks for the response! Yes, I totally agree that the best way is to pass in delimiter indices, avoid materializing the mask, and put the mask logic inside the kernel. Let me think about ways to partially avoid constructing the mask (or perhaps only create a smaller mask, with some extra indices passed in). Appreciate the help!
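The delimiter idea discussed above can be sketched as follows (all names hypothetical): instead of materializing a full custom mask, keep one integer `shared_len` for the prefix that every query attends to, plus a per-query `(start, end)` window, and generate each mask row on the fly from those indices.

```python
import numpy as np

def mask_from_delimiters(num_q, num_kv, shared_len, windows):
    """windows[i] = (start, end) of the private kv range query i attends to."""
    mask = np.zeros((num_q, num_kv), dtype=bool)
    mask[:, :shared_len] = True            # shared prefix for all queries
    for i, (s, e) in enumerate(windows):
        mask[i, s:e] = True                # per-query private kv range
    return mask

m = mask_from_delimiters(2, 6, shared_len=2, windows=[(2, 4), (4, 6)])
print(m.astype(int))
# [[1 1 1 1 0 0]
#  [1 1 0 0 1 1]]
```

In a kernel, the `shared_len` and window bounds would be passed in directly and compared against the kv index, so the O(q x kv) mask never needs to exist in memory.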
I noticed that in your previous version, you converted the float-type mask into a bit-packed array for mask usage. I would like to ask how much time this approach saves. I tested the execution time of the bit-packed-array mask against the causal kernel and found it to be about twice as slow, which still seems like significant overhead.
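For readers unfamiliar with the bit-packing mentioned above, a NumPy sketch of the idea: a boolean mask is packed 8 entries per byte, cutting mask memory (and the corresponding global-memory traffic) by 8x versus one byte per element, and 32x versus a float32 mask. The packing alone doesn't fix the non-coalesced access pattern, which is consistent with the residual overhead reported here.

```python
import numpy as np

# 128x128 boolean mask: 1 byte per element when unpacked
mask = np.tril(np.ones((128, 128), dtype=bool))

# pack 8 mask bits into each byte
packed = np.packbits(mask)

print(mask.nbytes, packed.nbytes)  # 16384 2048

# round-trip check: unpacking recovers the original mask
assert np.array_equal(np.unpackbits(packed).reshape(128, 128).astype(bool), mask)
```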