Custom mask slows down attention. #724

Open
qiyuxinlin opened this issue Jan 8, 2025 · 9 comments

Comments

@qiyuxinlin

I noticed that in a previous version you converted the float-type mask into a bit-packed array. How much time does this approach save? I measured the execution time of the bit-packed mask against the causal kernel and found that it runs about twice as slow, which still seems like significant overhead.
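For readers new to the trick being discussed: a bit-packed mask stores one bit per (query, key) entry instead of one float, so compared with a float32 mask (4 bytes per entry) it is 32x smaller. Below is a minimal host-side C++ sketch, not FlashInfer's kernel. The helper names (`pack_mask_bits`, `load_mask_bit`, `causal_mask`) are hypothetical, and the LSB-first bit order is an assumption chosen to match the `(mask_byte >> bit_idx) & 1` lookup quoted later in the thread.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack a row-major boolean mask into bytes, 8 entries per byte, LSB-first:
// entry idx is stored in byte idx / 8 at bit position idx % 8 (assumed layout).
std::vector<uint8_t> pack_mask_bits(const std::vector<bool>& mask) {
  std::vector<uint8_t> packed((mask.size() + 7) / 8, 0);
  for (std::size_t idx = 0; idx < mask.size(); ++idx) {
    if (mask[idx]) {
      packed[idx >> 3] |= static_cast<uint8_t>(1u << (idx & 7));
    }
  }
  return packed;
}

// Read entry idx back out of the packed mask (1 = attend, 0 = masked).
uint32_t load_mask_bit(const std::vector<uint8_t>& packed, std::size_t idx) {
  assert((idx >> 3) < packed.size());
  return (packed[idx >> 3] >> (idx & 7)) & 1u;
}

// Square causal mask for n tokens: query i may attend to keys 0..i.
std::vector<bool> causal_mask(std::size_t n) {
  std::vector<bool> mask(n * n, false);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j <= i; ++j) mask[i * n + j] = true;
  }
  return mask;
}
```

For a 5-token causal mask, the 25 entries fit in 4 bytes, and the first byte is 97 (0b01100001: entry 0 from row 0 plus entries 5 and 6 from row 1). Packing cuts the number of bytes read, but it does not by itself change how threads access the mask.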

@yzh119
Collaborator

yzh119 commented Jan 8, 2025

Yes, the custom mask has significant overhead because the memory access pattern to the custom mask is not coalesced.

For long sequences, we encourage using the sparse API instead:
https://docs.flashinfer.ai/api/sparse.html
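A minimal sketch of what "block-sparse" means here, assuming a BSR-style layout (`indptr` over block rows, `indices` listing the kept block columns). The struct and function names are hypothetical and are not FlashInfer's API; the linked docs describe the actual wrapper and its parameters.

```cpp
#include <cassert>
#include <vector>

// Hypothetical block-sparse-row (BSR) view of a dense mask; names are illustrative.
struct BlockSparseMask {
  std::vector<int> indptr;   // num_block_rows + 1 offsets into indices
  std::vector<int> indices;  // block-column index of each kept block
};

// Convert a row-major rows x cols mask into BSR with R x C blocks.
// A block is kept if any entry inside it is allowed.
BlockSparseMask dense_to_bsr(const std::vector<bool>& mask, int rows, int cols, int R, int C) {
  assert(rows % R == 0 && cols % C == 0);
  assert(static_cast<int>(mask.size()) == rows * cols);
  BlockSparseMask bsr;
  bsr.indptr.push_back(0);
  for (int br = 0; br < rows / R; ++br) {
    for (int bc = 0; bc < cols / C; ++bc) {
      bool any = false;
      for (int i = 0; i < R && !any; ++i) {
        for (int j = 0; j < C && !any; ++j) {
          any = mask[(br * R + i) * cols + (bc * C + j)];
        }
      }
      if (any) bsr.indices.push_back(bc);
    }
    bsr.indptr.push_back(static_cast<int>(bsr.indices.size()));
  }
  return bsr;
}
```

On a 4x4 causal mask with 2x2 blocks, three of the four blocks are kept (`indptr = {0, 1, 3}`, `indices = {0, 0, 1}`), and the two diagonal blocks are only partly filled, so they would still need element-level masking. The kept-block ratio (`indices.size()` over the total number of blocks) is a rough proxy for how much work a block-sparse kernel can skip: when a mask is mostly dense, nearly every block is kept and there is little to gain.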

@ZhongYingMatrix

Is there a better way to support tree attention (speculative decoding) than with a custom mask?

@yzh119
Collaborator

yzh119 commented Jan 18, 2025

@ZhongYingMatrix as I mentioned, using the sparse attention API would be much faster, especially for long context.

See this unit test for the equivalence of the two:
https://github.com/flashinfer-ai/flashinfer/blob/a0e99a3a820109763d9a757138a5cdf7bbcd1f85/tests/test_block_sparse.py
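For context on the tree-attention question: in tree-based speculative decoding, each draft token attends to itself and its ancestors in the draft tree. A minimal sketch of building that mask from a parent array, assuming tokens are topologically ordered (`parent[i] < i`). `tree_attention_mask` is a hypothetical helper, and the committed prefix, which every draft token would also attend to, is omitted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// parent[i] is the draft-tree parent of token i, or -1 for the root.
// Assumes topological order (parent[i] < i).
// Returns an n x n row-major mask: mask[i * n + j] is true iff j == i or j is an ancestor of i.
std::vector<bool> tree_attention_mask(const std::vector<int>& parent) {
  const std::size_t n = parent.size();
  std::vector<bool> mask(n * n, false);
  for (std::size_t i = 0; i < n; ++i) {
    assert(parent[i] < static_cast<int>(i));
    mask[i * n + i] = true;  // every token attends to itself
    if (parent[i] >= 0) {
      const std::size_t p = static_cast<std::size_t>(parent[i]);
      // The parent's row already marks the parent and all of its ancestors.
      for (std::size_t j = 0; j <= p; ++j) {
        if (mask[p * n + j]) mask[i * n + j] = true;
      }
    }
  }
  return mask;
}
```

With `parent = {-1, 0, 0, 1}`, row 3 allows tokens 0, 1, and 3, while row 2 allows tokens 0 and 2.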

@qingquansong

qingquansong commented Jan 18, 2025

Hey @yzh119, thanks for sharing the context! I have a similar issue in a different setting. My mask is a block-wise mask with the structure shown below, and it is dense for the most part. Which API would you suggest? I tried both, but because the left side of my mask is a fairly dense block, both APIs have high latency overhead. The sparse one can be slightly faster if I make the mask sparser, but in general the block-wise sparsity seems hard to exploit for a speedup here. Another option is calling the API multiple times, but that may also add significant latency. (I'm currently integrating this into sglang.) Thank you!

A more general question: for what sparsity ratio or pattern would you suggest using the sparse API rather than the dense one?

Also, I'm curious about the timeline for FP8 support in FA3 in 0.2.1, and whether sliding window and custom mask can be used together in FA3. Thank you in advance!

[Figure: block-wise mask structure]
(The first part of the figure can be quite dense and quite large, e.g. 32k x 24k, and it is followed by many small triangular blocks, e.g. 50 triangular blocks of 100 x 100 each.) I'm using version 0.1.6, btw.

@qingquansong

qingquansong commented Feb 5, 2025

Hey @yzh119, I dug in a bit, and there are some minor improvements we can make to speed things up:

  1. The custom mask segment bit-packing kernel seems to redo the same work multiple times in each thread. The packbits kernel seems fine, so I'm currently using that one 🤔

  2. Custom mask indexing and loading can be improved slightly here:

          const uint32_t byte_idx = idx >> 3;  // equivalent to idx / 8
          const uint32_t bit_idx  = idx & 7;   // equivalent to idx % 8
          // custom_mask points to the bit-packed mask bytes; __ldg loads through the read-only data cache
          const uint8_t mask_byte = __ldg(reinterpret_cast<const uint8_t*>(custom_mask + byte_idx));
          return (mask_byte >> bit_idx) & 1;

          // then later: uint32_t mask_bit = (q_idx < qo_len) ? ((mask_byte >> bit_idx) & 1) : 1;

But this is pretty minor.

I'm working more on the kernel side to see whether we can get further improvements and close more of the gap with the causal kernel.
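Regarding point 1 above, one way to avoid redundant per-thread work in segmented packing is to give each output byte to exactly one logical thread, which finds its segment with a binary search over the packed byte offsets. The sketch below emulates that mapping on the host in C++. It is not FlashInfer's kernel, the function names are hypothetical, and it assumes each segment starts on a byte boundary with LSB-first bit order.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// indptr[s]..indptr[s+1] is segment s of the concatenated boolean mask (in bits).
// Each segment is packed separately, starting on a byte boundary.
std::vector<int64_t> packed_indptr(const std::vector<int64_t>& indptr) {
  std::vector<int64_t> out(indptr.size(), 0);
  for (std::size_t s = 0; s + 1 < indptr.size(); ++s) {
    out[s + 1] = out[s] + (indptr[s + 1] - indptr[s] + 7) / 8;
  }
  return out;
}

// Work for one logical thread t: produce output byte t, and nothing else.
uint8_t pack_one_byte(const std::vector<bool>& bits, const std::vector<int64_t>& indptr,
                      const std::vector<int64_t>& out_indptr, int64_t t) {
  // Segment owning byte t: the last s with out_indptr[s] <= t (this skips empty segments).
  const auto it = std::upper_bound(out_indptr.begin(), out_indptr.end(), t);
  const std::size_t s = static_cast<std::size_t>(it - out_indptr.begin()) - 1;
  const int64_t first_bit = indptr[s] + (t - out_indptr[s]) * 8;
  uint8_t byte = 0;
  for (int b = 0; b < 8; ++b) {
    const int64_t src = first_bit + b;
    // LSB-first; bits past the end of the segment stay 0 as padding.
    if (src < indptr[s + 1] && bits[static_cast<std::size_t>(src)]) {
      byte |= static_cast<uint8_t>(1u << b);
    }
  }
  return byte;
}

// Host stand-in for a kernel launch with one thread per output byte.
std::vector<uint8_t> segment_pack(const std::vector<bool>& bits, const std::vector<int64_t>& indptr) {
  assert(!indptr.empty() && indptr.back() == static_cast<int64_t>(bits.size()));
  const std::vector<int64_t> out_indptr = packed_indptr(indptr);
  std::vector<uint8_t> out(static_cast<std::size_t>(out_indptr.back()), 0);
  for (int64_t t = 0; t < out_indptr.back(); ++t) {
    out[static_cast<std::size_t>(t)] = pack_one_byte(bits, indptr, out_indptr, t);
  }
  return out;
}
```

Two segments of 3 and 10 bits pack into bytes `{5, 255, 2}` with byte offsets `{0, 1, 3}`; an empty segment simply contributes no bytes.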

@qiyuxinlin
Author

It feels like your mask could be implemented on top of the causal mask. I think you could pass in two additional parameters indicating which rows and columns are masked, without materializing a full custom mask.
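A minimal sketch of the suggestion above: keep the causal rule and add two per-position flag vectors, so only `qo_len + kv_len` flags are stored instead of a `qo_len x kv_len` mask. The struct and function names are hypothetical, and the bottom-right alignment of the causal diagonal (when `kv_len > qo_len`) is an assumption. Masking arbitrary blocks is not covered here.

```cpp
#include <cassert>
#include <vector>

// Hypothetical parameters: a causal mask plus per-row and per-column "masked out" flags.
struct CausalWithMaskedRowsCols {
  int qo_len;
  int kv_len;
  std::vector<bool> row_masked;  // size qo_len; true = this query attends to nothing
  std::vector<bool> col_masked;  // size kv_len; true = no query attends to this key
};

// Assumes bottom-right-aligned causality: query q sits at KV position q + (kv_len - qo_len).
bool causal_with_holes_allowed(const CausalWithMaskedRowsCols& p, int q, int kv) {
  assert(q >= 0 && q < p.qo_len && kv >= 0 && kv < p.kv_len);
  const bool causal = kv <= q + (p.kv_len - p.qo_len);
  return causal && !p.row_masked[q] && !p.col_masked[kv];
}
```

With `qo_len = 3`, `kv_len = 5`, row 1 masked and column 3 masked, query 0 sees keys 0 to 2, query 1 sees nothing, and query 2 sees every key except 3.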

@qingquansong

qingquansong commented Feb 7, 2025

hey @qiyuxinlin, do you mean two vectors of delimiters to decide what is masked? As the figure shows, each q has its own attended KV as well as a jointly shared KV that all queries attend to, so I guess it's hard to use causal? (I can try to provide one int indicating the shared part for all the q's, if that's what you suggested.) Thank you in advance!

@qiyuxinlin
Author

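From your figure, I observe that overall it is a lower-triangular causal mask, except that some rows, columns, and a few blocks are also masked out. Would passing in only this information be enough for your needs? This is just my guess.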

@qingquansong

Thanks for the response! Yes, I totally agree that the best way is to pass in delimiter indices and put the mask logic inside the kernel, rather than materializing the mask. Let me think about ways to partially avoid constructing the mask (or perhaps only create a smaller mask with some extra indices passed in). Appreciate the help!
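A minimal sketch of the delimiter idea, under our own reading of the figure (not confirmed in the thread): the KV sequence is a shared prefix of length `shared_len` followed by per-segment KV aligned one-to-one with the queries, and each query attends to the whole prefix plus, causally, to its own segment. The struct, field, and function names are hypothetical. This replaces a `qo_len x kv_len` mask with one integer plus `num_segments + 1` boundaries; for scale, the 32k x 24k dense region mentioned earlier would alone take roughly 100 MB as a bit-packed mask (32k x 24k bits / 8).

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical delimiter-based mask description (our reading of the figure):
// KV = [shared prefix of length shared_len] + [per-segment KV aligned with the queries],
// i.e. query q corresponds to KV position shared_len + q.
struct SharedPrefixSegments {
  int shared_len;               // KV positions [0, shared_len) are visible to every query
  std::vector<int> seg_indptr;  // query-space segment boundaries; seg_indptr[0] == 0
};

bool shared_prefix_allowed(const SharedPrefixSegments& p, int q, int kv) {
  assert(!p.seg_indptr.empty() && p.seg_indptr.front() == 0);
  assert(q >= 0 && q < p.seg_indptr.back() && kv >= 0);
  if (kv < p.shared_len) return true;  // jointly shared KV
  // Start of the segment containing q: the last boundary <= q.
  const auto it = std::upper_bound(p.seg_indptr.begin(), p.seg_indptr.end(), q);
  const int seg_start = *(it - 1);
  const int local = kv - p.shared_len;  // KV position in query-aligned coordinates
  // Causal (lower-triangular) within q's own segment; other segments are masked.
  return local >= seg_start && local <= q;
}
```

With `shared_len = 2` and `seg_indptr = {0, 2, 5}`, query 3 (in segment 1) sees KV positions 0, 1, 4, and 5, but not 3 (segment 0) or 6 (a later position in its own segment).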
