Removed the assertion imposed on cu_seqlens_k and seqused_k #59
base: main
Conversation
Both FA2 and FA3 allow cu_seqlens_k and seqused_k to be non-None at the same time. For example, the following tests passed with both cu_seqlens_k and seqused_k non-None:

```
$ pytest ./tests/test_vllm_flash_attn.py -k test_flash_attn_varlen_output
...
collected 2536 items / 2408 deselected / 128 selected
Running 128 items in this shard

tests/test_vllm_flash_attn.py ................................................................................................................ [ 87%]
................ [100%]

======================================================== 128 passed, 2408 deselected in 4.02s ========================================================
```

This PR also fixes a minor compilation error.

Signed-off-by: Yang Chen <[email protected]>
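For illustration, here is a minimal sketch of a call where both arguments are non-None. The keyword names follow the flash_attn_varlen_func interface touched by this PR, but the import path, parameter set, and tensor shapes below are assumptions for illustration, not the actual test code:

```python
# Hypothetical sketch (not the actual test): pass cu_seqlens_k and seqused_k together.
import torch
from vllm_flash_attn import flash_attn_varlen_func  # assumed import path

num_heads, head_dim = 8, 128
seqlens_q = [5, 7]    # per-sequence query lengths
seqlens_k = [16, 16]  # per-sequence key/value lengths described by cu_seqlens_k
used_k = [9, 12]      # K/V tokens actually used per sequence (seqused_k)

q = torch.randn(sum(seqlens_q), num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn(sum(seqlens_k), num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
v = torch.randn_like(k)

cu_seqlens_q = torch.tensor([0, 5, 12], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 16, 32], dtype=torch.int32, device="cuda")
seqused_k = torch.tensor(used_k, dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(
    q=q, k=k, v=v,
    cu_seqlens_q=cu_seqlens_q,
    max_seqlen_q=max(seqlens_q),
    max_seqlen_k=max(seqlens_k),
    cu_seqlens_k=cu_seqlens_k,  # previously mutually exclusive with seqused_k
    seqused_k=seqused_k,        # with this PR, both may be provided
    causal=True,
)
```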
@@ -388,7 +388,7 @@ class VarlenDynamicPersistentTileScheduler {
   // If Split, for the purpose of scheduling, we pretend that instead there are
   // (args.num_splits * args.num_head) number of heads.
   assert(args.tile_count_semaphore != nullptr);
-  assert(num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx
+  assert(args.num_head < (1 << 16));  // We use the top 16 bits to store num_splits & split_idx
how did we compile this before?
I guess we don't use the problematic version yet?
https://github.com/vllm-project/vllm/blob/main/cmake/external_projects/vllm_flash_attn.cmake#L41
I will update the tag once this PR gets approved and merged.
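For context on the assertion itself: the scheduler comment says the top 16 bits of a packed word hold the split information, which is why num_head has to fit in 16 bits. The packing below is only an illustration of that constraint, assuming such a layout, and is not the actual scheduler code:

```python
# Illustration only (assumed layout): if a 16-bit split field shares a 32-bit
# word with the head index, the head index must fit in the low 16 bits,
# matching assert(args.num_head < (1 << 16)) in the scheduler.
def pack(head_idx: int, split_idx: int) -> int:
    assert head_idx < (1 << 16) and split_idx < (1 << 16)
    return (split_idx << 16) | head_idx

def unpack(packed: int) -> tuple[int, int]:
    return packed & 0xFFFF, packed >> 16

assert unpack(pack(head_idx=7, split_idx=3)) == (7, 3)
```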
  # Check that FlashAttention's numerical error is at most 3x
  # the numerical error of a Pytorch implementation.
  assert (
Should we have two checks, one for rtol and one for atol?
It's from FA's implementation:
https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L477-L479
I think it might be fine to keep their setup?
  # Numerical error if we just do any arithmetic on out_ref
  fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
  rtol = 2
is rtol too large? should it be 1e-2?
This is copied from flash attention's value:
https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py#L185
Maybe worth citing in the comment?
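For reference, here is how that FA-style tolerance setup reads when written out in one place. The fwd_atol and rtol lines mirror the diff above; the stand-in tensors and the final assertion are my reading of the check in Dao-AILab/flash-attention's hopper/test_flash_attn.py (linked above), not necessarily the exact lines in this test:

```python
import torch

# Stand-in tensors so the sketch runs; in the test these come from the kernels.
out_ref = torch.randn(2, 7, 8, 128)                  # reference attention output
out_pt = out_ref + 1e-3 * torch.randn_like(out_ref)  # "plain PyTorch" output
out = out_ref + 1e-3 * torch.randn_like(out_ref)     # FlashAttention output

# Noise floor from simply doing arithmetic on out_ref
# (cf. Dao-AILab/flash-attention hopper/test_flash_attn.py, linked above).
fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
rtol = 2

# rtol and atol are combined into a single check rather than two separate ones:
# FA's max error must stay within rtol times the PyTorch implementation's max
# error, plus the additive noise floor.
assert (out - out_ref).abs().max().item() <= \
    rtol * (out_pt - out_ref).abs().max().item() + fwd_atol
```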
@@ -29,6 +35,117 @@
  ([3] if is_fa_version_supported(3) else [])


+ # This function is copied from hopper/test_utils.py
+ def attention_ref(
Can this replace ref_attn below? Or vice versa? It would be nice to have less duplication here.
Good question. I tried it last Friday and it didn't work. ref_attn didn't work with the new tests, and attention_ref failed with some of the existing tests. I didn't spend effort digging into the issue, though.
BTW, if the overall change looks good, I am wondering if it would be fine to merge the PR while keeping the two versions for now. We have some dependency on this change.
I will make sure to have a follow-up PR to unify them into one. Sorry, I am running short on bandwidth at the moment to investigate why each reference fails on the other's tests. Thanks!
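For readers who haven't looked at the two helpers: both ref_attn and attention_ref are pure-PyTorch references that the FA output is compared against. A stripped-down sketch of what such a reference computes (ignoring padding masks, dropout, sliding windows, and the other options the real helpers support) might look like:

```python
import math
import torch

def simple_attention_ref(q, k, v, causal=False):
    """Minimal reference attention in fp32 for numerical comparison.

    q: (batch, seqlen_q, nheads, head_dim)
    k, v: (batch, seqlen_k, nheads, head_dim)
    """
    q, k, v = q.float(), k.float(), v.float()
    scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(q.shape[-1])
    if causal:
        seqlen_q, seqlen_k = q.shape[1], k.shape[1]
        # Align the causal diagonal to the bottom-right corner, as FA does.
        mask = torch.triu(
            torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device),
            diagonal=seqlen_k - seqlen_q + 1,
        )
        scores.masked_fill_(mask, float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhqk,bkhd->bqhd", attn, v)
```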
@@ -199,8 +199,6 @@ def flash_attn_varlen_func(
      """
      assert cu_seqlens_k is not None or seqused_k is not None, \
          "cu_seqlens_k or seqused_k must be provided"
-     assert cu_seqlens_k is None or seqused_k is None, \
-         "cu_seqlens_k and seqused_k cannot be provided at the same time"
Can you please explain to me the use case for using cu_seqlens_k and seqused_k simultaneously? Right now we use cu_seqlens_k for initial prefills (no context) and seqused_k for chunked prefill and decode (i.e. anytime there is a page table).
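To make the two existing call patterns concrete, here is a rough sketch of the metadata each one builds. The values are made up, and the real vLLM attention backends construct these tensors differently; this is only meant to illustrate the distinction being discussed:

```python
import torch

# Pattern 1: initial prefill, no prior context. K/V are packed varlen tensors,
# so cumulative sequence lengths describe them exactly.
seqlens_k = [16, 8, 24]
cu_seqlens_k = torch.tensor([0, 16, 24, 48], dtype=torch.int32)  # cumsum with leading 0

# Pattern 2: chunked prefill / decode with a page table. K/V live in fixed-size
# blocks, so what matters is how many K/V tokens each sequence actually uses.
seqused_k = torch.tensor([9, 130, 41], dtype=torch.int32)

# This PR removes the Python-side assertion that forbade passing both at once.
```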
We have a special use case for this but I can't expose the detail at the moment. Sorry about that!
LGTM. Please consider the refactoring (mentioned in the comment above) as a follow-up PR!
    ],
)
@pytest.mark.parametrize("fa_version", VERSIONS)
def test_flash_attn_varlen_output(
Can we make the test name more specific?
Spoke with @chenyang78 offline: we are going to hold off on this for now and temporarily use a branch (https://github.com/vllm-project/flash-attention/tree/main-0.8.2.post1) with this PR and #58 cherry-picked onto the commit used by current vLLM main. This is due to accuracy issues (a drop in GSM8K accuracy) stemming from bugs in the latest FA3 code.
Repo: https://github.com/vllm-project/flash-attention.git
Commit: dc9d410b3e2d6534a4c70724c2515f4def670a22 <- current vLLM main
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 100
...
2025-04-01:18:53:05,675 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192), gen_kwargs: (None), limit: 100.0, num_fewshot: 5, batch_size: 1
|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ | 0.49|± |0.0502|
| | |strict-match | 5|exact_match|↑ | 0.49|± |0.0502|
Repo: https://github.com/chenyang78/vllm-flash-attention
Commit: 70281734675afefa0bf36d2d87c399e84c75fd61 <- This PR
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 100
...
|Tasks|Version| Filter |n-shot| Metric | |Value| |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ | 0.46|± |0.0501|
| | |strict-match | 5|exact_match|↑ | 0.46|± |0.0501|
The accuracy issues seem to be related to #56 and the new dynamic split scheduler, whose integration into vLLM main is still in progress.
(GSM8K accuracy checks will be added to the SOP for future upstream syncs; this shows that unit tests alone are insufficient.)
Thanks for the review, @WoosukKwon. I will make a follow-up PR to address the feedback.

Sounds good. Thank you for helping out and making a workaround, Lucas!

Since we are using a temporary branch, I think we can just update this PR at a later time.

Sounds good. Will do. Thanks.