This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

LayerNorm, what is going on? #136

Open
ju-w opened this issue Dec 26, 2022 · 1 comment

ju-w commented Dec 26, 2022

Hello!

I am unsure what LayerNorm applied to images is actually supposed to do. With `data_format="channels_first"`, LayerNorm behaves somewhat like BatchNorm2d, but the output shows suspicious vertical lines. With `data_format="channels_last"`, however, it completely breaks the image, because each pixel is normalized individually. Is this the intended behaviour?

Please have a look at the experiment below to see what I mean.

Related: #112, possibly #115

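```python
import torch
from torch import nn
import torchvision
from models.convnext import LayerNorm  # ConvNeXt's LayerNorm (models/convnext.py), run from the repo root

# read_image returns a uint8 tensor of shape (3, H, W); multiplying by a float
# promotes it to float32 with values in roughly [0, 1.275]
img = torchvision.io.read_image("cat.jpg")
img = img*0.005
torchvision.utils.save_image(img, "cat0.jpg")
```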

[image: cat0.jpg, the input image]

BatchNorm2d

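```python
bn = nn.BatchNorm2d(3)
# add a batch dimension, (3, H, W) -> (1, 3, H, W), then drop it again
img_bn = bn(img.unsqueeze(0))[0]
# shift/scale the normalized output into a displayable range before saving
torchvision.utils.save_image(img_bn*.28+.7, "cat_bn.jpg")
```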

[image: cat_bn.jpg, BatchNorm2d output]
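For reference, `nn.BatchNorm2d` in training mode (the default for a freshly constructed module, as in the experiment above) normalizes each channel with the mean and biased variance computed over the batch and both spatial axes. The sketch below checks this against a manual computation on random data; `manual_batchnorm2d` is a hypothetical helper, and the comparison relies on the default `eps=1e-5` and the initial affine parameters (weight 1, bias 0).

```python
import torch
from torch import nn


def manual_batchnorm2d(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: normalize each channel over (N, H, W)."""
    mean = x.mean(dim=(0, 2, 3), keepdim=True)
    # biased variance (divide by the element count), as BatchNorm uses in training
    var = ((x - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x = torch.rand(1, 3, 8, 8)
bn = nn.BatchNorm2d(3)  # training mode by default; weight=1, bias=0 at init
out = bn(x)
ref = manual_batchnorm2d(x)
print(torch.allclose(out, ref, atol=1e-5))  # True
```

Each channel therefore shares one mean and one variance across the whole image, which is why the result still looks like a (contrast-adjusted) photo.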

LayerNorm

channels_first

```python
ln_cf = LayerNorm(3, data_format="channels_first")
# img is passed as-is: shape (3, H, W), no batch dimension
img_cf = ln_cf(img)
torchvision.utils.save_image(img_cf*.28+.7, "cat_cf.jpg")
```

[image: cat_cf.jpg, LayerNorm channels_first output]
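One way to see which axis the `channels_first` branch reduces over is to replicate its arithmetic. The sketch below assumes, as in ConvNeXt's `models/convnext.py`, that this branch takes the mean and biased variance over dim 1 before applying `weight[:, None, None]` and `bias[:, None, None]`; `channels_first_ln` is a hypothetical stand-in with weight 1 and bias 0, not the repository's class. On a batched `(N, C, H, W)` tensor, dim 1 is the channel axis; on an unbatched `(C, H, W)` tensor such as `img` above, dim 1 is the height axis, so each (channel, column) slice is normalized separately.

```python
import torch


def channels_first_ln(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical stand-in for the channels_first branch (weight=1, bias=0):
    mean and biased variance are taken over dim 1, whatever that axis is."""
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    return (x - u) / torch.sqrt(s + eps)


torch.manual_seed(0)
img = torch.rand(3, 4, 5)  # unbatched (C, H, W)

# Unbatched: dim 1 is H, so every (channel, column) slice along H has mean ~0
out_3d = channels_first_ln(img)
print(out_3d.mean(dim=1).abs().max().item() < 1e-5)  # True

# Batched: dim 1 is C, so every pixel's 3 channel values have mean ~0
out_4d = channels_first_ln(img.unsqueeze(0))[0]
print(out_4d.mean(dim=0).abs().max().item() < 1e-5)  # True

# The two normalizations give different results
print(torch.allclose(out_3d, out_4d))  # False
```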

channels_last

```python
ln_cl = LayerNorm(3, data_format="channels_last")
# move channels last, (3, H, W) -> (H, W, 3), normalize, then move them back
img_cl = ln_cl(img.permute(1,2,0)).permute(2,0,1)
torchvision.utils.save_image(img_cl*.28+.7, "cat_cl.jpg")
```

[image: cat_cl.jpg, LayerNorm channels_last output]
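The `channels_last` path normalizes over the last axis only. Assuming, as in ConvNeXt's `models/convnext.py`, that it calls `torch.nn.functional.layer_norm` with `normalized_shape=(C,)`, each pixel's three values are standardized on their own, consistent with the observation above that each pixel is normalized individually. The sketch below checks that on an `(H, W, 3)` tensor, and also that the result matches the `channels_first` arithmetic when that arithmetic is given a batched `(1, 3, H, W)` input. `channels_first_ln` is a hypothetical helper replicating that arithmetic with weight 1 and bias 0.

```python
import torch
import torch.nn.functional as F


def channels_first_ln(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: mean and biased variance over dim 1 (weight=1, bias=0)."""
    u = x.mean(1, keepdim=True)
    s = (x - u).pow(2).mean(1, keepdim=True)
    return (x - u) / torch.sqrt(s + eps)


torch.manual_seed(0)
img = torch.rand(3, 4, 5)  # (C, H, W)

# channels_last: normalize the last axis (size 3) of an (H, W, C) tensor
hwc = img.permute(1, 2, 0)
out_cl = F.layer_norm(hwc, normalized_shape=(3,), eps=1e-6).permute(2, 0, 1)

# Each pixel's 3 channel values now have mean ~0
print(out_cl.mean(dim=0).abs().max().item() < 1e-5)  # True

# Same result as the channels_first arithmetic on a batched (1, C, H, W) input
out_cf = channels_first_ln(img.unsqueeze(0))[0]
print(torch.allclose(out_cl, out_cf, atol=1e-5))  # True
```

With only three values per pixel, each output triple is forced to zero mean and unit variance, so the original colour and brightness of every pixel are largely discarded.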

ju-w commented Dec 26, 2022

See also: pytorch/pytorch#51455
