Parallel SageAttention Inference #50

Open · wants to merge 13 commits into base: main
Conversation

@DefTruth commented Nov 26, 2024

This PR adds a small workaround that makes SageAttention compatible with distributed environments, for example xDiT, which is launched by torchrun. Without this workaround, SageAttention runs into an illegal memory access error after the first inference step in multi-GPU distributed inference. The workaround also makes SageAttention compatible with torch.compile via the non-fullgraph compile mode.

@jason-huang03
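
For context, here is a minimal sketch (not the PR's actual diff) of the integration pattern described above: routing torch's scaled_dot_product_attention through sageattn and compiling with fullgraph=False so graph breaks around the custom kernel are tolerated. The wrapper and its fallback conditions are illustrative assumptions.

```python
import torch.nn.functional as F
from sageattention import sageattn  # assumes the sageattention package is installed

_orig_sdpa = F.scaled_dot_product_attention

def sage_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, **kwargs):
    # Fall back to the stock kernel for cases the SageAttention kernel does not cover.
    if attn_mask is not None or dropout_p != 0.0:
        return _orig_sdpa(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p,
                          is_causal=is_causal, **kwargs)
    # "HND" layout matches SDPA's (batch, heads, seq_len, head_dim) convention.
    return sageattn(q, k, v, tensor_layout="HND", is_causal=is_causal)

# Plug-and-play replacement; in a distributed run this executes on every rank
# launched by torchrun.
F.scaled_dot_product_attention = sage_sdpa

# Non-fullgraph compilation allows graph breaks at the custom attention call,
# e.g. (hypothetical pipeline object):
# pipe.transformer = torch.compile(pipe.transformer, fullgraph=False)
```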

@DefTruth changed the title from Parallel SageAttention Inference Support to Parallel SageAttention Inference on Nov 26, 2024
@jason-huang03 (Member)

Thanks a lot! We will check the implementation and merge the PR.

@jason-huang03 jason-huang03 self-assigned this Nov 26, 2024
@DefTruth (Author) commented Nov 26, 2024

You may need to install the latest xDiT from source if your environment already has FlashAttention >= 2.7.0. I just made a hotfix to keep the ring flash attention forward pass compatible with the latest FlashAttention, so it will not run into a function-launch error.
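
As an illustrative aside (not part of the PR), one way to check whether your environment falls into the FlashAttention >= 2.7.0 case mentioned above:

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

try:
    fa_version = Version(version("flash-attn"))
except PackageNotFoundError:
    fa_version = None  # FlashAttention not installed

if fa_version is not None and fa_version >= Version("2.7.0"):
    # The ring flash attention forward hotfix in the latest xDiT is needed here.
    print("flash-attn >= 2.7.0 detected: install xDiT from source to pick up the hotfix.")
```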

Also, the plug-and-play SageAttention currently only works with CFG parallelism on 2 GPUs.
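
For readers unfamiliar with the term, here is a conceptual sketch of what CFG parallelism over 2 GPUs means (illustrative only, not xDiT's implementation; model, cond_emb, and uncond_emb are hypothetical): each rank evaluates one classifier-free-guidance branch and the guided output is formed after an all-gather.

```python
import torch
import torch.distributed as dist

def cfg_parallel_step(model, latents, cond_emb, uncond_emb, guidance_scale=7.5):
    # Assumes a 2-process group initialized by torchrun. Rank 0 runs the
    # conditional branch, rank 1 the unconditional branch; SageAttention is
    # used inside `model` on both ranks.
    emb = cond_emb if dist.get_rank() == 0 else uncond_emb
    out = model(latents, emb)

    # Exchange the two branches and apply the usual CFG combination on every rank.
    gathered = [torch.empty_like(out) for _ in range(2)]
    dist.all_gather(gathered, out)
    cond_out, uncond_out = gathered[0], gathered[1]
    return uncond_out + guidance_scale * (cond_out - uncond_out)
```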

@jason-huang03 (Member)
I am busy these days and I will dive into it as soon as I can.
