Skip to content

Commit

Permalink
Some more kron work. Figured out why some tests fail, implemented a d…
Browse files Browse the repository at this point in the history
…eterministic rng state load but too slow so skipping some tests for now.
  • Loading branch information
rwightman committed Jan 27, 2025
1 parent de2f5c6 commit f759d12
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 81 deletions.
10 changes: 9 additions & 1 deletion tests/test_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def _build_params_dict_single(weight, bias, **kwargs):
return [dict(params=bias, **kwargs)]


@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*')))
@pytest.mark.parametrize('optimizer', list_optimizers(exclude_filters=('fused*', 'bnb*', 'kron*')))
def test_optim_factory(optimizer):
assert issubclass(get_optimizer_class(optimizer, bind_defaults=False), torch.optim.Optimizer)

Expand Down Expand Up @@ -386,6 +386,14 @@ def test_adam(optimizer):
_test_model(optimizer, dict(lr=5e-2))


@pytest.mark.parametrize('optimizer', ['kron'])
def test_kron(optimizer):
_test_rosenbrock(
lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
)
_test_model(optimizer, dict(lr=1e-3))


@pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
def test_adopt(optimizer):
_test_rosenbrock(
Expand Down
9 changes: 8 additions & 1 deletion timm/optim/_optim_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,9 +697,16 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
OptimInfo(
name='kron',
opt_class=Kron,
description='',
description='PSGD optimizer with Kronecker-factored preconditioner',
has_momentum=True,
),
OptimInfo(
name='kronw',
opt_class=Kron,
description='PSGD optimizer with Kronecker-factored preconditioner and decoupled weight decay',
has_momentum=True,
defaults={'decoupled_decay': True}
),
OptimInfo(
name='laprop',
opt_class=LaProp,
Expand Down
Loading

0 comments on commit f759d12

Please sign in to comment.