Skip to content

Commit

Permalink
More additions to Kron
Browse files Browse the repository at this point in the history
  • Loading branch information
rwightman committed Jan 27, 2025
1 parent f759d12 commit 71d1741
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions timm/optim/kron.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,6 @@
import numpy as np
import torch

try:
# NOTE opt_einsum needed to avoid blowing up memory with einsum ops
import opt_einsum
opt_einsum.enabled = True
opt_einsum.strategy = "auto-hq"
import torch.backends.opt_einsum
has_opt_einsum = True
except ImportError:
has_opt_einsum = False

try:
torch._dynamo.config.cache_size_limit = 1_000_000
Expand Down Expand Up @@ -67,19 +58,20 @@ class Kron(torch.optim.Optimizer):
params: Iterable of parameters to optimize or dicts defining parameter groups.
lr: Learning rate.
momentum: Momentum parameter.
weight_decay: Weight decay (L2 penalty).
weight_decay: Weight decay.
preconditioner_update_probability: Probability of updating the preconditioner.
If None, defaults to a schedule that anneals from 1.0 to 0.03 by 4000 steps.
max_size_triangular: Max size for dim's preconditioner to be triangular.
min_ndim_triangular: Minimum number of dimensions a layer needs to have triangular preconditioners.
memory_save_mode: 'one_diag', or 'all_diag', None is default
memory_save_mode: 'one_diag', 'smart_one_diag', or 'all_diag', None is default
to set all preconditioners to be triangular, 'one_diag' sets the largest
or last dim to be diagonal per layer, and 'all_diag' sets all preconditioners to be diagonal.
momentum_into_precond_update: whether to send momentum into preconditioner
update instead of raw gradients.
mu_dtype: Dtype of the momentum accumulator.
precond_dtype: Dtype of the preconditioner.
decoupled_decay: AdamW style decoupled-decay.
decoupled_decay: AdamW style decoupled weight decay
flatten_dim: Flatten dim >= 2 instead of relying on expressions
deterministic: Deterministic behaviour across save / load (resume). FIXME slow, needs work
"""

Expand All @@ -97,10 +89,18 @@ def __init__(
mu_dtype: Optional[torch.dtype] = None,
precond_dtype: Optional[torch.dtype] = None,
decoupled_decay: bool = False,
flatten_dim: bool = False,
deterministic: bool = False,
):
if not has_opt_einsum:
try:
# NOTE opt_einsum needed to avoid blowing up memory with einsum ops
import opt_einsum
opt_einsum.enabled = True
opt_einsum.strategy = "auto-hq"
import torch.backends.opt_einsum
except ImportError:
warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )

if not 0.0 <= lr:
raise ValueError(f"Invalid learning rate: {lr}")
if not 0.0 <= momentum < 1.0:
Expand All @@ -122,10 +122,11 @@ def __init__(
mu_dtype=mu_dtype,
precond_dtype=precond_dtype,
decoupled_decay=decoupled_decay,
flatten_dim=flatten_dim,
)
super(Kron, self).__init__(params, defaults)

self._param_exprs = {}
self._param_exprs = {} # cache for einsum expr
self._tiny = torch.finfo(torch.bfloat16).tiny
self.rng = random.Random(1337)
if deterministic:
Expand Down Expand Up @@ -165,20 +166,21 @@ def state_dict(self) -> Dict[str, Any]:

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
# Extract and remove the RNG state from the state dict
rng_state = state_dict.pop('rng_state', None)
torch_rng_state = state_dict.pop('torch_rng_state', None)
rng_states = {}
if 'rng_state' in state_dict:
rng_states['rng_state'] = state_dict.pop('rng_state')
if 'torch_rng_state' in state_dict:
rng_states['torch_rng_state'] = state_dict.pop('torch_rng_state')

# Load the optimizer state
super().load_state_dict(state_dict)
state_dict.update(rng_states) # add back

# Restore the RNG state if it exists
if rng_state is not None:
self.rng.setstate(rng_state)
state_dict['rng_state'] = rng_state # put it back if caller still using state_dict
if torch_rng_state is not None:
if self.torch_rng is not None:
self.torch_rng.set_state(torch_rng_state)
state_dict['torch_rng_state'] = torch_rng_state # put it back if caller still using state_dict
if 'rng_state' in rng_states:
self.rng.setstate(rng_states['rng_state'])
if 'torch_rng_state' in rng_states:
self.torch_rng.set_state(rng_states['torch_rng_state'])

def __setstate__(self, state):
super().__setstate__(state)
Expand Down Expand Up @@ -208,13 +210,16 @@ def step(self, closure=None):

grad = p.grad
state = self.state[p]
if group['flatten_dim']:
grad = grad.view(grad.size(0), -1)

if len(state) == 0:
state["step"] = 0
state["update_counter"] = 0
state["momentum_buffer"] = torch.zeros_like(p, dtype=mu_dtype or p.dtype)
state["momentum_buffer"] = torch.zeros_like(grad, dtype=mu_dtype or grad.dtype)
# init Q and einsum expressions on first step
state["Q"], exprs = _init_Q_exprs(
p,
grad,
group["precond_init_scale"],
group["max_size_triangular"],
group["min_ndim_triangular"],
Expand All @@ -234,8 +239,9 @@ def step(self, closure=None):
total_precond_size += precond_size
total_precond_mb += precond_mb
elif p not in self._param_exprs:
# init only the einsum expressions, called after state load, Q are loaded from state_dict
exprs = _init_Q_exprs(
p,
grad,
group["precond_init_scale"],
group["max_size_triangular"],
group["min_ndim_triangular"],
Expand All @@ -245,6 +251,7 @@ def step(self, closure=None):
)
self._param_exprs[p] = exprs
else:
# retrieve cached expressions
exprs = self._param_exprs[p]

# update preconditioners all together deterministically
Expand Down Expand Up @@ -315,6 +322,8 @@ def step(self, closure=None):

# RMS of pre_grad should be 1.0, so let's cap at 1.1
pre_grad.mul_(torch.clamp(1.1 / (pre_grad.square().mean().sqrt_() + 1e-8), max=1.0))
if group['flatten_dim']:
pre_grad = pre_grad.view(p.shape)

# Apply weight decay
if group["weight_decay"] != 0:
Expand Down Expand Up @@ -369,9 +378,10 @@ def _init_Q_exprs(
dim_diag = [False for _ in shape]
dim_diag[rev_sorted_dims[0]] = True
elif memory_save_mode == "smart_one_diag":
dim_diag = [False for _ in shape]
# addition proposed by Lucas Nestler
rev_sorted_dims = np.argsort(shape)[::-1]
sorted_shape = sorted(shape)
dim_diag = [False for _ in shape]
if len(shape) >= 2 and sorted_shape[-1] > sorted_shape[-2]:
dim_diag[rev_sorted_dims[0]] = True
elif memory_save_mode == "all_diag":
Expand Down

0 comments on commit 71d1741

Please sign in to comment.