# mupify
Call `mupify(model, optimizer, param)` on your model and optimizer. They'll be modified in-place so that the forward/backward passes reflect the chosen parameterization. See `example.ipynb` for a tutorial.
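
A minimal sketch of the intended workflow (the architecture, hyperparameters, and the value of `param` here are illustrative assumptions; see `example.ipynb` and the documentation in `mupify.py` for the real options):

```python
import torch
import torch.nn as nn
from mupify import mupify  # assumes mupify.py is on your path

# A toy MLP using only dense linear layers (supported by mupify).
model = nn.Sequential(
    nn.Linear(784, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 10),
)

# SGD with momentum is fine; adaptive optimizers are not supported.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Modify model and optimizer in-place. The value "mup" for the `param`
# argument is an assumed placeholder, not a documented option.
mupify(model, optimizer, "mup")

# Train as usual: forward/backward now follow the chosen parameterization.
x, y = torch.randn(32, 784), torch.randint(0, 10, (32,))
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```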
If you use this in a research project, please consider citing https://arxiv.org/abs/2404.19719!
Not intended for use with any of the following:
- Adaptive optimizers. (SGD + momentum and/or weight decay are fine.)
- Linear layers other than dense linear layers or 2d convolutions.
- Attention blocks.
Important notes:
- `nn.ReLU()` layers are mupified to evaluate $\mathrm{max}(0, x\sqrt{2})$ rather than $\mathrm{max}(0, x)$. To avoid this behavior, use `torch.nn.functional.relu` instead; see the sketch after this list.
- The user-facing functions are `mupify(model, optimizer, param)` and `rescale(model, gamma)`. See the documentation in `mupify.py`.
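
To make the `nn.ReLU()` note concrete: since $\sqrt{2} > 0$, $\mathrm{max}(0, x\sqrt{2}) = \sqrt{2}\,\mathrm{max}(0, x)$, so a mupified module ReLU is just a $\sqrt{2}$-scaled ReLU. The `ScaledReLU` below is a stand-in for that behavior (not mupify's actual implementation), and `Block` shows how calling the functional form in `forward` opts out:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stand-in for a mupified ReLU: max(0, x*sqrt(2)) == sqrt(2)*relu(x).
class ScaledReLU(nn.Module):
    def forward(self, x):
        return F.relu(x * math.sqrt(2.0))

# Calling the functional relu directly (instead of registering an
# nn.ReLU() module) keeps the plain max(0, x) behavior under mupify.
class Block(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.fc = nn.Linear(d, d)

    def forward(self, x):
        return F.relu(self.fc(x))  # not rescaled by mupify

x = torch.randn(4, 8)
assert torch.allclose(ScaledReLU()(x), math.sqrt(2.0) * F.relu(x))
```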