Removed the assertion imposed on cu_seqlens_k and seqused_k #59

Open · chenyang78 wants to merge 1 commit into main

Conversation

chenyang78

Both FA2 and FA3 allow cu_seqlens_k and seqused_k to be non-None at the same time. For example, the following tests passed, where both cu_seqlens_k and seqused_k are non-None:

```
$ pytest ./tests/test_vllm_flash_attn.py -k test_flash_attn_varlen_output
...
collected 2536 items / 2408 deselected / 128 selected
Running 128 items in this shard

tests/test_vllm_flash_attn.py ................................................................................................................ [ 87%]
................                                                                                                                               [100%]

======================================================== 128 passed, 2408 deselected in 4.02s ========================================================
```
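For context, the two arguments carry different metadata: cu_seqlens_k is the cumulative (prefix-sum) sequence-length tensor over the packed key batch, while seqused_k gives the number of keys actually used per sequence. A minimal sketch of how such tensors might be built (the lengths, dtypes, and the choice to use every key below are illustrative assumptions, not taken from the tests):

```
import torch

# Hypothetical per-sequence key lengths for a packed batch of 3 sequences.
seqlens_k = torch.tensor([128, 64, 32], dtype=torch.int32)

# cu_seqlens_k: cumulative sequence lengths (prefix sum), length batch + 1.
cu_seqlens_k = torch.zeros(len(seqlens_k) + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)

# seqused_k: how many keys of each sequence are actually attended; here we
# simply use every key, but it may be smaller than the allocated length.
seqused_k = seqlens_k.clone()

print(cu_seqlens_k.tolist())  # [0, 128, 192, 224]
print(seqused_k.tolist())     # [128, 64, 32]
```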

This PR also fixes a minor compilation error.

```
@@ -388,7 +388,7 @@ class VarlenDynamicPersistentTileScheduler {
     // If Split, for the purpose of scheduling, we pretend that instead there are
     // (args.num_splits * args.num_head) number of heads.
     assert(args.tile_count_semaphore != nullptr);
-    assert(num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx
+    assert(args.num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx
```


how did we compile this before?

chenyang78 (Author):

I guess we don't use the problematic version yet?

https://github.com/vllm-project/vllm/blob/main/cmake/external_projects/vllm_flash_attn.cmake#L41

I will update the tag once this PR gets approved and merged.
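Stepping back from the compilation question: the comment on the asserted line explains the constraint itself. If split information is packed into the top 16 bits of a 32-bit word, the head-related index must fit in the low 16 bits, hence num_head < (1 << 16). A small Python illustration of that packing constraint (this is not the scheduler's actual encoding; pack/unpack are hypothetical helpers):

```
# Illustration of the 16/16-bit packing constraint behind the assert above.
# Not the scheduler's real code; pack/unpack are hypothetical helpers.

def pack(head_idx: int, split_idx: int) -> int:
    assert head_idx < (1 << 16), "head index must fit in the low 16 bits"
    assert split_idx < (1 << 16), "split info must fit in the high 16 bits"
    return (split_idx << 16) | head_idx

def unpack(packed: int):
    return packed & 0xFFFF, packed >> 16

packed = pack(head_idx=40, split_idx=3)
assert unpack(packed) == (40, 3)
```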


```
# Check that FlashAttention's numerical error is at most 3x
# the numerical error of a Pytorch implementation.
assert (
```


Should we have two checks, one for rtol and one for atol?

chenyang78 (Author):

It's from FA's implementation:

https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L477-L479

I think it might be fine to keep their setup?


```
# Numerical error if we just do any arithmetic on out_ref
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2
```


Is rtol too large? Should it be 1e-2?


Maybe worth citing in the comment?
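To address both questions in this thread in one place: the check follows flash-attention's own tests, where the kernel's maximum error is compared against the error of a plain PyTorch implementation run in low precision, scaled by rtol, plus a small absolute floor derived from out_ref. A self-contained sketch with stand-in tensors (the names out / out_ref / out_pt and the shapes are assumptions here, not copied from the test file):

```
import torch

torch.manual_seed(0)
# Stand-ins: out is the fused-kernel output, out_ref is an fp32 reference,
# out_pt is a reference run in the kernel's low precision.
out_ref = torch.randn(2, 64, 4, 128)
out_pt = out_ref + 1e-3 * torch.randn_like(out_ref)
out = out_ref + 5e-4 * torch.randn_like(out_ref)

# Absolute floor: the error from merely doing arithmetic on out_ref.
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2

# The kernel may be at most rtol times as far from the fp32 reference as the
# plain PyTorch low-precision implementation, plus the absolute floor.
assert (out - out_ref).abs().max().item() \
    <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol
```

Under this scheme, rtol is a relative factor against a reference error of the same precision, not an absolute tolerance like 1e-2, which is why it looks large at first glance.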

```
@@ -29,6 +35,117 @@
([3] if is_fa_version_supported(3) else [])


# This function is copied from hopper/test_utils.py
def attention_ref(
```
Collaborator:

Can this replace ref_attn below, or vice versa? It would be nice to have less duplication here.

chenyang78 (Author):

Good question. I tried it last Friday and it didn't work: ref_attn didn't work with the new tests, and attention_ref failed with some of the existing tests. I didn't spend effort digging into the issue, though.

chenyang78 (Author):

BTW, if the overall change looks good, I am wondering if it would be fine to merge the PR while keeping the two versions for the moment. We have some dependency on this change.

I will make sure to have a follow-up PR to unify them into one. I'm sorry, I'm out of bandwidth right now to investigate why each reference fails the other's tests. Thanks!
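For readers comparing the two helpers: both compute a straightforward PyTorch attention to serve as ground truth. A minimal sketch of that kind of reference (not a copy of either attention_ref or ref_attn; the layout and masking assumptions are spelled out in the comments, and features such as MQA/GQA, padding masks, and sliding windows are omitted):

```
import math
import torch

def naive_attention(q, k, v, causal=False):
    # Assumed layout: [batch, seqlen, heads, head_dim], with equal numbers
    # of query and key heads.
    q, k, v = (t.float().transpose(1, 2) for t in (q, k, v))  # -> [b, h, s, d]
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if causal:
        mask = torch.ones(scores.shape[-2:]).tril().bool()
        scores = scores.masked_fill(~mask, float("-inf"))
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2)  # back to [b, s, h, d]

q = k = v = torch.randn(1, 8, 2, 16)
print(naive_attention(q, k, v, causal=True).shape)  # torch.Size([1, 8, 2, 16])
```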

```
@@ -199,8 +199,6 @@ def flash_attn_varlen_func(
     """
     assert cu_seqlens_k is not None or seqused_k is not None, \
         "cu_seqlens_k or seqused_k must be provided"
-    assert cu_seqlens_k is None or seqused_k is None, \
-        "cu_seqlens_k and seqused_k cannot be provided at the same time"
```
@LucasWilkinson (Collaborator), Mar 31, 2025:

Can you please explain to me the use case for using cu_seqlens_k and seqused_k simultaneously? Right now we use cu_seqlens_k for initial prefills (no context) and seqused_k for chunked prefill and decode (i.e. anytime there is a page table).

chenyang78 (Author):

We have a special use case for this, but I can't share the details at the moment. Sorry about that!
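For illustration only, a call where both tensors are supplied now looks roughly like the sketch below. The import path, the parameter names other than cu_seqlens_k and seqused_k, and the shapes are assumptions based on the standard FA varlen interface (see the diff above), and running it requires a CUDA device with the extension built:

```
import torch
# Import path is an assumption; adjust to wherever the wrapper is installed.
from vllm_flash_attn import flash_attn_varlen_func

assert torch.cuda.is_available()
b, s, h, d = 2, 128, 4, 128
q = torch.randn(b * s, h, d, dtype=torch.float16, device="cuda")
k = torch.randn(b * s, h, d, dtype=torch.float16, device="cuda")
v = torch.randn(b * s, h, d, dtype=torch.float16, device="cuda")
cu_seqlens = torch.arange(0, (b + 1) * s, s, dtype=torch.int32, device="cuda")
seqused_k = torch.full((b,), s, dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,  # both provided at once ...
    seqused_k=seqused_k,      # ... which this PR now permits
    max_seqlen_q=s,
    max_seqlen_k=s,
    causal=True,
)
```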

@WoosukKwon left a comment:

LGTM. Please consider refactoring (mentioned in the comment) as a followup PR!



```
    ],
)
@pytest.mark.parametrize("fa_version", VERSIONS)
def test_flash_attn_varlen_output(
```


Can we make the test name more specific?

@LucasWilkinson (Collaborator) left a comment:

Spoke with @chenyang78 offline; we are going to hold off on this for now and temporarily use a branch (https://github.com/vllm-project/flash-attention/tree/main-0.8.2.post1) with this PR and #58 cherry-picked onto the commit used by current vLLM main. This is due to accuracy issues (a drop in GSM8K accuracy) stemming from bugs in the latest FA3 code.

```
Repo:   https://github.com/vllm-project/flash-attention.git
Commit: dc9d410b3e2d6534a4c70724c2515f4def670a22 <- current vLLM main

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 100
...
2025-04-01:18:53:05,675 INFO     [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.49|±  |0.0502|
|     |       |strict-match    |     5|exact_match|↑  | 0.49|±  |0.0502|
```


```
Repo:   https://github.com/chenyang78/vllm-flash-attention
Commit: 70281734675afefa0bf36d2d87c399e84c75fd61 <- This PR

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 100
...
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  | 0.46|±  |0.0501|
|     |       |strict-match    |     5|exact_match|↑  | 0.46|±  |0.0501|
```

The accuracy issues seem to be related to #56 and the new dynamic split scheduler, whose integration into vLLM main is still in progress.

(GSM8K accuracy checks will be added to the SOP for future upstream syncs; this shows that unit tests alone are insufficient.)

@chenyang78 (Author):

> LGTM. Please consider refactoring (mentioned in the comment) as a followup PR!

Thanks for the review, @WoosukKwon. I will make a follow-up PR to address the feedback.

@chenyang78 (Author):

> Spoke with @chenyang78 offline; we are going to hold off on this for now and temporarily use a branch (https://github.com/vllm-project/flash-attention/tree/main-0.8.2.post1) with this PR and #58 cherry-picked onto the commit used by current vLLM main. This is due to accuracy issues (a drop in GSM8K accuracy) stemming from bugs in the latest FA3 code.

Sounds good. Thank you for helping out and making a workaround, Lucas!

@LucasWilkinson (Collaborator):

> LGTM. Please consider refactoring (mentioned in the comment) as a followup PR!
>
> Thanks for the review, @WoosukKwon. I will make a follow-up PR to address the feedback.

Since we are using a temporary branch, I think we can just update this PR at a later time.

@chenyang78 (Author):

> LGTM. Please consider refactoring (mentioned in the comment) as a followup PR!
>
> Thanks for the review, @WoosukKwon. I will make a follow-up PR to address the feedback.
>
> Since we are using a temporary branch, I think we can just update this PR at a later time.

Sounds good. Will do. Thanks.
