Lack of clarity about sim function vs feature map for paper/code #1
Hi, thank you for your interest, and sorry for the confusion: phi(x) = x^2 is simplified notation. In the case of sim(q, k) = (q^T k)^2 we indeed have phi(q) != q^2 elementwise. However, we can still factorize our similarity function into feature maps. Check out the tests with the reference implementation: rebased/flash_linear_attention/fla/layers/rebased_fast.py, lines 66 to 70 at fc11fa1.
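To make the distinction concrete, here is a minimal check (not from the repository, just an illustration) that elementwise squaring does not reproduce (q^T k)^2, while the flattened self outer product does:

```python
import torch

torch.manual_seed(0)
q, k = torch.randn(16), torch.randn(16)

sim = (q @ k) ** 2                      # sim(q, k) = (q^T k)^2
elementwise = (q ** 2) @ (k ** 2)       # dot(q^2, k^2): a different quantity
outer_flat = torch.outer(q, q).flatten() @ torch.outer(k, k).flatten()

print(torch.isclose(sim, elementwise))  # False in general
print(torch.isclose(sim, outer_flat))   # True: phi(x) = flatten(x x^T) factorizes sim
```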
Still confused. These lines you link to (rebased/flash_linear_attention/fla/layers/rebased_fast.py, lines 66 to 70 at fc11fa1) correspond to computing a self outer product and flattening. To clarify, it seems like you agree that the similarity function is sim(q, k) = (q^T k)^2 and that the corresponding feature map is phi(x) = flatten(outer(x, x)). Agree?
Yes, you are correct. I guess the source of your misunderstanding is the mix-up between the parallel and the linear computation forms. As we stated above, we want to factorize the multiplication (q^T k)^2 into a dot product of feature maps, phi(q)^T phi(k). Managing this factorization is easy, since

(q^T k)^2 = (sum_i q_i k_i)^2 = sum_{i,j} q_i q_j k_i k_j = flatten(q q^T)^T flatten(k k^T),

so phi(x) = flatten(x x^T).
Here is a simple snippet that could help you:

```python
import torch


def x_2(x: torch.Tensor):
    # Get 2nd-order terms; equivalent to
    # rearrange(x.unsqueeze(-1) * x.unsqueeze(-2), '... m n -> ... (m n)')
    x2 = (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2)
    return x2  # simple case without the normalization of attention scores


if __name__ == "__main__":
    torch.manual_seed(5)
    q = torch.randn(2, 3, 6)
    k = torch.randn(2, 3, 6)
    v = torch.randn(2, 3, 6)

    # Reshape to (batch, heads, seq, dim).
    q = q.view(2, 3, 1, -1).transpose(1, 2)
    k = k.view(2, 3, 1, -1).transpose(1, 2)
    v = v.view(2, 3, 1, -1).transpose(1, 2)

    # Parallel form: square the attention scores, then mix the values.
    qk = torch.einsum("bhqd,bhkd->bhqk", q, k)
    parallel_res = torch.einsum("bhqk,bhkd->bhqd", qk ** 2, v)

    # Linear form: apply the feature map and reorder the computation.
    q, k, v = q.unsqueeze(-2), k.unsqueeze(-2), v.unsqueeze(-1)
    linear_factorized_res = (x_2(q) * (x_2(k) * v).sum(2, True)).sum(-1)

    print(torch.max(parallel_res - linear_factorized_res))  # tensor(1.1444e-05)
```

I have omitted the normalization of attention scores and the causal masking, but the general idea should be clear.
Yes. I agree. The verification code can be even simpler:

```python
import numpy as np

q = np.random.normal(0, 1, 64)
k = np.random.normal(0, 1, 64)
self_outer_flatten = lambda x: np.outer(x, x).flatten()
assert np.allclose(np.dot(q, k) ** 2,
                   np.dot(self_outer_flatten(q), self_outer_flatten(k)))
```

Good. We're on the same page there. Returning to my original post and the title of this issue: your paper seems to be inconsistent with your code.
Your paper seems to me to be conflating a similarity function (also called a kernel function) with a feature map (phi). There is no element-wise squaring anywhere in the code, and your paper doesn't use the term "feature map", "feature function", or "feature representation" anywhere.
Noting that we provided a one-dimensional scenario in the paper for simplicity, we agree that the notation for the kernel function and the feature map there can be confusing. We will update it to make it more accurate. Thank you for pointing that out!
Hi, I read your paper and found the following confusing. When you describe the ablations that culminate in ReBased, you start from the feature map phi(x) = x^2, but this doesn't seem to be what happens in your code. See these lines:

rebased/flash_linear_attention/fla/ops/triton/rebased_fast/parallel.py, lines 68 to 69 at 7a085b4
My understanding of Linear Attention is the following. We need two functions: a similarity function (called `sim` or `s`) which takes two vectors and returns a scalar, and a feature map (typically called `phi`) which takes a single vector and returns another vector (possibly of a different dimension). Ignoring the normalization by `1/sqrt(d)` for simplicity, Linear Attention requires that

`s(q, k) = dot(phi(q), phi(k))`

Those lines of code I linked to correspond to defining

`s(q, k) = dot(q, k)**2`

The feature map `phi` which corresponds to this similarity function is not elementwise squaring. I.e., `phi(x) = x**2` is not the corresponding feature map for that similarity function. The correct corresponding feature map is

`phi(x) = flatten(outer(x, x))`

One could say similar things about the other variants in the ablations, including ReBased.

Am I missing something?
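For completeness, here is a small sketch (my own, not code from the repository) of how this plays out in the causal case with phi(x) = flatten(outer(x, x)): the masked quadratic form and the linear recurrent form give the same output. Normalization is still omitted.

```python
import torch


def phi(x):
    # Feature map for s(q, k) = (q^T k)^2: flatten the self outer product.
    return (x.unsqueeze(-1) * x.unsqueeze(-2)).flatten(start_dim=-2)


torch.manual_seed(0)
T, d = 5, 4
q, k, v = torch.randn(T, d), torch.randn(T, d), torch.randn(T, d)

# Parallel (quadratic) form with a causal mask, no normalization.
scores = (q @ k.T) ** 2                  # s(q_i, k_j) = (q_i . k_j)^2
mask = torch.tril(torch.ones(T, T))      # causal mask
parallel = (scores * mask) @ v           # (T, d)

# Linear (recurrent) form: accumulate phi(k_j) v_j^T as a running state.
state = torch.zeros(d * d, d)
linear = torch.zeros(T, d)
for t in range(T):
    state = state + torch.outer(phi(k[t]), v[t])  # running sum of phi(k_j) v_j^T
    linear[t] = phi(q[t]) @ state                 # y_t = phi(q_t)^T * state

print(torch.allclose(parallel, linear, atol=1e-5))  # True
```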