Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an isolated implementation of FlashDiffAttention #1633

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

zhuzilin
Copy link

@zhuzilin zhuzilin commented Oct 9, 2024

This PR is trying to implement a FlashDiffAttention class similar to the FlashSelfAttention in the origin flash attention repo (https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py#L53), so that training frameworks could easily add diff transformer support with and without varlen support.

The main idea is to set the num_head in the training process twice as the origin transformer so that we no longer need to change the code relates to RoPE.

A simple test script for the code is:

from dataclasses import dataclass

import torch
import torch.distributed as dist
from flash_attn.layers.rotary import RotaryEmbedding
from einops import rearrange

from multihead_flashdiff_2 import MultiheadFlashDiff2
from flashdiff import FlashDiffAttention
from kernel.rotary import apply_rotary_emb


@dataclass
class Args:
    model_parallel_size: int
    decoder_kv_attention_heads: int


def create_new_impl(origin_impl, head_dim, depth):
    diff_attn_func = FlashDiffAttention(
        head_dim=embed_dim // num_new_heads, depth=depth, causal=True
    ).to(device, dtype=dtype)
    # make the initialization the same
    diff_attn_func.lambda_q1.data.copy_(origin_impl.lambda_q1.data)
    diff_attn_func.lambda_k1.data.copy_(origin_impl.lambda_k1.data)
    diff_attn_func.lambda_q2.data.copy_(origin_impl.lambda_q2.data)
    diff_attn_func.lambda_k2.data.copy_(origin_impl.lambda_k2.data)
    #diff_attn_func.subln.weight.data.copy_(origin_impl.subln.weight.data)
    
    def new_impl(x, rel_pos):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len

        q = origin_impl.q_proj(x)
        k = origin_impl.k_proj(x)
        v = origin_impl.v_proj(x)

        # here we no longer need "// 2"
        num_heads = embed_dim // head_dim
        num_kv_heads = k.shape[-1] // head_dim

        q = q.view(bsz, tgt_len, num_heads, head_dim)
        k = k.view(bsz, src_len, num_kv_heads, head_dim)
        v = v.view(bsz, src_len, num_kv_heads, head_dim)

        q = apply_rotary_emb(q, *rel_pos, interleaved=True)
        k = apply_rotary_emb(k, *rel_pos, interleaved=True)

        output = diff_attn_func(q, k, v)
        output = rearrange(output, '... H D -> ... (H D)')

        output = origin_impl.out_proj(output)
        return output
    
    return new_impl


if __name__ == "__main__":
    dist.init_process_group(backend="nccl")
    device = torch.device("cuda")
    dtype = torch.bfloat16
    args = Args(model_parallel_size=1, decoder_kv_attention_heads=4)
    batch_size = 2
    num_heads = 16
    seq_len = 512
    embed_dim = 2048
    depth = 12
    # in the new implementation, the num_heads should be twice the original num_heads
    num_new_heads = num_heads * 2
    head_dim = embed_dim // num_new_heads

    print("initializing modules")
    origin_impl = MultiheadFlashDiff2(args, embed_dim=embed_dim, depth=depth, num_heads=num_heads).to(device, dtype=dtype)
    new_impl = create_new_impl(origin_impl, head_dim, depth)

    print("creating test data")
    rotary_emb = RotaryEmbedding(
        head_dim,
        base=10000.0,
        interleaved=True,
        device=device,
    )
    rotary_emb._update_cos_sin_cache(seq_len, device=device, dtype=torch.bfloat16)
    rel_pos = (rotary_emb._cos_cached, rotary_emb._sin_cached)
    hidden_states = torch.randn((batch_size, seq_len, embed_dim), device=device, dtype=dtype)

    print("run origin forward")
    origin_output = origin_impl(hidden_states, rel_pos)

    print("run new forward")
    new_output = new_impl(hidden_states, rel_pos)

    assert torch.allclose(origin_output, new_output, atol=1e-6)

Thank you for your time on reviewing this PR.

@zhuzilin zhuzilin force-pushed the feature/flash_diff_attn branch 2 times, most recently from 200479b to f9f35a8 Compare October 9, 2024 06:29
@zhuzilin zhuzilin changed the title [WIP] Add a isolated implementation of FlashDiffAttention Add an isolated implementation of FlashDiffAttention Oct 9, 2024
@zhuzilin zhuzilin force-pushed the feature/flash_diff_attn branch from f9f35a8 to c6e6486 Compare October 9, 2024 06:36
@MarktHart
Copy link

You could go even closer to attention and use it as is with a doubled interleave. E.g.

def alternative_forward(
        self,
        x,
        rel_pos,
        attn_mask=None,
    ):
    bsz, tgt_len, embed_dim = x.size()
    src_len = tgt_len

    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)

    q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)
    k = k.view(bsz, src_len, 2 * self.num_kv_heads, self.head_dim)
    v = v.view(bsz, src_len, self.num_kv_heads, 2 * self.head_dim)

    q = apply_rotary_emb(q, *rel_pos, interleaved=True)
    k = apply_rotary_emb(k, *rel_pos, interleaved=True)

    q = q.transpose(1, 2)
    
    k = torch.repeat_interleave(k.transpose(1, 2), dim=1, repeats=self.n_rep)
    v = torch.repeat_interleave(v.transpose(1, 2), dim=1, repeats=self.n_rep * 2)
    if attn_mask is None:
        attn_mask = torch.triu(
            torch.zeros([tgt_len, src_len])
            .float()
            .fill_(float("-inf"))
            .type_as(q),
            1 + src_len - tgt_len,
        )

    lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
    lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
    lambda_full = lambda_1 - lambda_2 + self.lambda_init

    attn_weights = F.scaled_dot_product_attention(query=q, key=k, value=v, attn_mask=attn_mask, scale=self.scaling)
    every_other_mask = torch.arange(attn_weights.size(1)) % 2 == 0
    attn = attn_weights[:, every_other_mask, :, :] - lambda_full * attn_weights[:, ~every_other_mask, :, :]

    attn = self.subln(attn)
    attn = attn * (1 - self.lambda_init)
    attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)

    attn = self.out_proj(attn)
    return attn

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants