
Incorrect RMSNorm #4

Open
arunmallya opened this issue Mar 13, 2024 · 3 comments

Comments

@arunmallya

The RMSNorm implementation in this codebase is wrong: it computes the RMS over the (T, D) dimensions instead of over the D dimension alone. Assume the input x has shape (B, T, D).

The current code does this:

```python
# x is (B, T, D).
ff_rms = torch.linalg.norm(x, dim=(1,2)) * x[0].numel() ** -.5  # (B,): one RMS per sequence.
raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)  # Denominator reshaped to (B, 1, 1).
```

The RMSNorm in Meta's original LLaMA code is here: https://github.com/meta-llama/llama/blob/main/llama/model.py#L34-L77

```python
x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
```
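Written out for one sample, the two statistics differ as follows (notation is ours: $x_{t,d}$ is feature $d$ of token $t$, and $x_t$ is the $D$-dimensional vector of token $t$). The current code divides every element of the sample by the first quantity; LLaMA divides each token by the second. The identity on the right of the second line is why dividing the L2 norm by $\sqrt{D}$ recovers the per-token RMS.

$$
\mathrm{RMS}_{\text{seq}}(x) = \sqrt{\frac{1}{TD}\sum_{t=1}^{T}\sum_{d=1}^{D} x_{t,d}^2} = \frac{\lVert x \rVert_F}{\sqrt{TD}}
$$

$$
\mathrm{RMS}_{\text{tok}}(x_t) = \sqrt{\frac{1}{D}\sum_{d=1}^{D} x_{t,d}^2} = \frac{\lVert x_t \rVert_2}{\sqrt{D}}
$$

$$
\text{LLaMA: } \hat{x}_{t,d} = \frac{x_{t,d}}{\sqrt{\mathrm{RMS}_{\text{tok}}(x_t)^2 + \epsilon}}, \qquad \text{norm-based fix: } \hat{x}_{t,d} = \frac{x_{t,d}}{\mathrm{RMS}_{\text{tok}}(x_t) + \epsilon}
$$

The two per-token forms agree exactly when $\epsilon = 0$ and differ only slightly for a small $\epsilon$.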

The corrected version, using the L2 norm over the last dimension, would be:

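```python
import math
import torch
ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1): one RMS per token.
raw = x / (ff_rms + eps)  # Small eps (e.g. 1e-8) added after the root; LLaMA adds it inside.
```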
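As a sanity check, the norm-based fix and LLaMA's `rsqrt` formula should produce the same output when eps is zero. The sketch below uses random input and two helper functions, `llama_norm` and `norm_based`, whose names are our own. The learnable gain is left out because it does not affect the comparison.

```python
import math

import torch

torch.manual_seed(0)
B, T, D = 2, 5, 8
x = torch.randn(B, T, D)  # illustrative random input of shape (B, T, D)


def llama_norm(x, eps):
    # LLaMA-style: eps inside the square root, statistic over the last dim.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)


def norm_based(x, eps):
    # Norm-based per-token RMS from the issue; eps added after the root.
    ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1)
    return x / (ff_rms + eps)


a = llama_norm(x, eps=0.0)
b = norm_based(x, eps=0.0)
print(torch.allclose(a, b, atol=1e-6))
# Each token now has unit RMS over its D features.
print(a.pow(2).mean(-1).sqrt())
```

With a small nonzero eps the two outputs remain close but are no longer identical, because eps enters the denominator at different points.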

Normalization should be per-token, not per-sequence.
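One way to see why per-sequence normalization is a problem: with the current code, a token's normalized values depend on every other token in the same sequence. The sketch below uses random input and hypothetical helper names (`per_sequence_norm`, `per_token_norm`). It scales only the last token and checks whether the first token's output changes.

```python
import math

import torch


def per_sequence_norm(x):
    # Current code: one RMS per sample over all T*D elements (no eps).
    ff_rms = torch.linalg.norm(x, dim=(1, 2)) * x[0].numel() ** -0.5  # (B,)
    return x / ff_rms.unsqueeze(-1).unsqueeze(-1)  # divide by a (B, 1, 1) tensor


def per_token_norm(x, eps=1e-8):
    # Proposed fix: one RMS per token over the last (D) dimension.
    ff_rms = torch.linalg.norm(x, dim=-1, keepdim=True) / math.sqrt(x.shape[-1])  # (B, T, 1)
    return x / (ff_rms + eps)


torch.manual_seed(0)
x = torch.randn(1, 4, 8)  # one sequence of 4 tokens, 8 features each
x2 = x.clone()
x2[:, -1] *= 100.0        # make only the last token much larger

# Does the normalized first token change when only the last token changed?
seq_changed = not torch.allclose(per_sequence_norm(x)[:, 0], per_sequence_norm(x2)[:, 0])
tok_changed = not torch.allclose(per_token_norm(x)[:, 0], per_token_norm(x2)[:, 0])
print(seq_changed, tok_changed)
```

Under per-sequence normalization the first token shrinks because the last one grew. Under per-token normalization it is unaffected.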

@nkkbr

nkkbr commented May 12, 2024

I agree with you.

@nkkbr

nkkbr commented May 12, 2024

My version:

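```python
import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super().__init__()
        # Learnable gain of shape layer_shape = (max_seq_len, d_model),
        # as implied by the slicing in forward(). `bias` is accepted but not used.
        self.scale = nn.Parameter(torch.ones(layer_shape))
        self.eps = eps

    def forward(self, x):
        """
        Assumes x has shape (batch, seq_len, d_model).
        """
        # Per-token reciprocal RMS over the last (feature) dim: (batch, seq_len, 1).
        f = torch.rsqrt(torch.mean(x.pow(2), dim=-1, keepdim=True) + self.eps)
        # Slice the gain to the current sequence length and broadcast over the batch.
        return x * f * self.scale[: x.shape[1], :].unsqueeze(0)
```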
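The version above keeps a separate gain for every position, sliced to the current sequence length. For comparison, the LLaMA code linked in the issue uses a single gain vector of shape (d_model,) shared by all positions, with eps inside the square root. It also computes the statistic in float32 and casts back. A minimal sketch of that variant follows; the class name `RMSNormPerFeature` and the eps default are our own choices for illustration.

```python
import torch
import torch.nn as nn


class RMSNormPerFeature(nn.Module):
    """Per-token RMSNorm with one gain vector of shape (d_model,) shared by all positions."""

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., d_model). Compute the per-token statistic in float32,
        # then cast back to the input dtype before applying the gain.
        xf = x.float()
        normed = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return normed.type_as(x) * self.weight


torch.manual_seed(0)
norm = RMSNormPerFeature(d_model=16)
x = torch.randn(2, 10, 16)  # (batch, seq_len, d_model); any seq_len works
y = norm(x)
print(y.shape)
print(y.pow(2).mean(-1).sqrt())  # close to 1 for every token while the gain is all ones
```

With a (d_model,) gain the module needs no maximum sequence length. The per-position version above, by contrast, requires layer_shape to cover the longest sequence it will see.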

@bkitano
Owner

bkitano commented May 29, 2024

hi! open a PR?
