Support flash-attention in the form of a custom call (#18)
ApsarasX authored May 8, 2024
1 parent df671df commit d1f9ce4
Showing 17 changed files with 2,960 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/run_tests.sh
@@ -235,6 +235,7 @@ function run_xla_op_tests3 {
  # NOTE: the line below tests export and does not care about GPU
PJRT_DEVICE=CPU CPU_NUM_DEVICES=1 run_coverage "$CDIR/test_core_aten_ops.py"
run_test "$CDIR/test_pallas.py"
PJRT_DEVICE=CUDA pytest -v "$CDIR/test_flash_attn.py"

# CUDA tests
if [ -x "$(command -v nvidia-smi)" ]; then
1,695 changes: 1,695 additions & 0 deletions test/test_flash_attn.py

Large diffs are not rendered by default.
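
The test file itself is not rendered here, so the following is only a hypothetical sketch of how the new wrapper might be exercised; the function name, shapes, dtypes, and tolerances are assumptions, not the contents of test_flash_attn.py. It compares the custom-call output against PyTorch's reference attention:

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.core.functions import flash_attn


def test_flash_attn_matches_reference():
  device = xm.xla_device()
  # Shapes assume the flash-attention (batch, seqlen, num_heads, head_dim) layout.
  q = torch.randn(2, 128, 8, 64, dtype=torch.float16, device=device)
  k = torch.randn_like(q)
  v = torch.randn_like(q)

  out = flash_attn(q, k, v, is_causal=True)
  xm.mark_step()

  # Reference computed in float32 on CPU, using the (batch, heads, seq, dim) layout.
  ref = F.scaled_dot_product_attention(
      q.cpu().float().transpose(1, 2),
      k.cpu().float().transpose(1, 2),
      v.cpu().float().transpose(1, 2),
      is_causal=True).transpose(1, 2)
  torch.testing.assert_close(out.cpu().float(), ref, atol=2e-2, rtol=2e-2)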

314 changes: 314 additions & 0 deletions torch_xla/core/functions.py
@@ -193,3 +193,317 @@ def forward(ctx, x):
@staticmethod
def backward(ctx, grad_output):
return xm.all_reduce(xm.REDUCE_SUM, grad_output)


def _flash_attn_fwd(query,
key,
value,
*,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
return_softmax=False):
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
query, key, value = [maybe_contiguous(x) for x in (query, key, value)]
return torch_xla._XLAC._xla_flash_attn_fwd(
query,
key,
value,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax)


def _flash_attn_varlen_fwd(query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
*,
max_seqlen_q,
max_seqlen_k,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
return_softmax=False):
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
query, key, value = [maybe_contiguous(x) for x in (query, key, value)]
return torch_xla._XLAC._xla_flash_attn_varlen_fwd(
query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax)


def _flash_attn_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
*,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
deterministic=False,
):
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
grad_output, query, key, value, output, softmax_lse = [
maybe_contiguous(x)
for x in (grad_output, query, key, value, output, softmax_lse)
]
grad_query, grad_key, grad_value, grad_softmax = torch_xla._XLAC._xla_flash_attn_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
)
return grad_query, grad_key, grad_value, grad_softmax


def _flash_attn_varlen_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
cu_seqlens_query,
cu_seqlens_key,
*,
max_seqlen_q,
max_seqlen_k,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
deterministic=False,
):
maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
grad_output, query, key, value, output = [
maybe_contiguous(x) for x in (grad_output, query, key, value, output)
]
  grad_query, grad_key, grad_value, grad_softmax = torch_xla._XLAC._xla_flash_attn_varlen_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
deterministic=deterministic,
)
return grad_query, grad_key, grad_value, grad_softmax


class FlashAttn(torch.autograd.Function):

@staticmethod
def forward(
ctx,
query,
key,
value,
dropout_rate,
scale,
is_causal,
alibi_slopes,
deterministic,
return_softmax,
):
if scale is None:
scale = query.shape[-1]**(-0.5)
output, softmax_lse, rng_state, S_dmask = _flash_attn_fwd(
query,
key,
value,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_rate > 0)
ctx.save_for_backward(query, key, value, output, softmax_lse, rng_state)
ctx.dropout_rate = dropout_rate
ctx.scale = scale
ctx.is_causal = is_causal
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return output if not return_softmax else (output, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, grad_output, *args):
query, key, value, output, softmax_lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value, _ = _flash_attn_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
dropout_rate=ctx.dropout_rate,
scale=ctx.scale,
is_causal=ctx.is_causal,
alibi_slopes=ctx.alibi_slopes,
deterministic=ctx.deterministic,
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None


class FlashAttnVarLen(torch.autograd.Function):

@staticmethod
def forward(
ctx,
query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q,
max_seqlen_k,
dropout_rate,
scale,
is_causal,
alibi_slopes,
deterministic,
return_softmax,
):
if scale is None:
scale = query.shape[-1]**(-0.5)
output, softmax_lse, rng_state, S_dmask = _flash_attn_varlen_fwd(
query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_rate=dropout_rate,
scale=scale,
is_causal=is_causal,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax and dropout_rate > 0)
ctx.save_for_backward(query, key, value, cu_seqlens_query, cu_seqlens_key,
output, softmax_lse, rng_state)
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.dropout_rate = dropout_rate
ctx.scale = scale
ctx.is_causal = is_causal
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return output if not return_softmax else (output, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, grad_output, *args):
query, key, value, cu_seqlens_query, cu_seqlens_key, output, softmax_lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value, _ = _flash_attn_varlen_bwd(
grad_output,
query,
key,
value,
output,
softmax_lse,
rng_state,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q=ctx.max_seqlen_q,
max_seqlen_k=ctx.max_seqlen_k,
dropout_rate=ctx.dropout_rate,
scale=ctx.scale,
is_causal=ctx.is_causal,
alibi_slopes=ctx.alibi_slopes,
deterministic=ctx.deterministic,
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None, None


def flash_attn(
query,
key,
value,
*,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
deterministic=False,
return_softmax=False,
):
return FlashAttn.apply(
query,
key,
value,
dropout_rate,
scale,
is_causal,
alibi_slopes,
deterministic,
return_softmax,
)


def flash_attn_varlen(
query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
*,
max_seqlen_q,
max_seqlen_k,
dropout_rate=0.0,
scale=None,
is_causal=False,
alibi_slopes=None,
deterministic=False,
return_softmax=False,
):
return FlashAttnVarLen.apply(
query,
key,
value,
cu_seqlens_query,
cu_seqlens_key,
max_seqlen_q,
max_seqlen_k,
dropout_rate,
scale,
is_causal,
alibi_slopes,
deterministic,
return_softmax,
)
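
For orientation, here is a usage sketch of the two new Python entry points. It is not part of the diff: the shapes, dtypes, and the packed (total_tokens, num_heads, head_dim) varlen layout are assumptions based on the usual flash-attention conventions.

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm
from torch_xla.core.functions import flash_attn, flash_attn_varlen

device = xm.xla_device()

# Dense (fixed-length) attention: (batch, seqlen, num_heads, head_dim) inputs.
q = torch.randn(4, 256, 16, 64, dtype=torch.bfloat16, device=device)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn(q, k, v, dropout_rate=0.1, is_causal=True)

# Variable-length attention: two sequences of length 100 and 156 packed into a
# single (total_tokens, num_heads, head_dim) tensor, with cumulative
# sequence-length offsets [0, 100, 256].
seqlens = torch.tensor([100, 156], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, 0, dtype=torch.int32), (1, 0)).to(device)
total_tokens = int(seqlens.sum())
q_packed = torch.randn(total_tokens, 16, 64, dtype=torch.bfloat16, device=device)
k_packed = torch.randn_like(q_packed)
v_packed = torch.randn_like(q_packed)
out_packed = flash_attn_varlen(
    q_packed, k_packed, v_packed,
    cu_seqlens, cu_seqlens,
    max_seqlen_q=int(seqlens.max()),
    max_seqlen_k=int(seqlens.max()),
    is_causal=True)
xm.mark_step()

The cu_seqlens tensors built here correspond to the cu_seqlens_query and cu_seqlens_key arguments that FlashAttnVarLen.forward saves alongside the activations for the backward pass.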
2 changes: 2 additions & 0 deletions torch_xla/csrc/BUILD
@@ -43,6 +43,7 @@ ptxla_cc_library(
"data_ops.cpp",
"debug_util.cpp",
"elementwise.cpp",
"flash_attn_util.cpp",
"helpers.cpp",
"ir_dump_util.cpp",
"matrix.cpp",
@@ -82,6 +83,7 @@ ptxla_cc_library(
"data_ops.h",
"debug_util.h",
"elementwise.h",
"flash_attn_util.h",
"generated_file_include.h",
"helpers.h",
"ir_dump_util.h",
(Diffs for the remaining changed files are not shown in this view.)
