Move opt_einsum import back out of class __init__
rwightman committed Jan 27, 2025
1 parent 71d1741 commit 80a0205
Showing 1 changed file with 16 additions and 14 deletions.
timm/optim/kron.py: 30 changes (16 additions & 14 deletions)
@@ -14,7 +14,15 @@
 
 import numpy as np
 import torch
 
+try:
+    # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
+    import opt_einsum
+    import torch.backends.opt_einsum
+    torch.backends.opt_einsum.enabled = True
+    torch.backends.opt_einsum.strategy = "auto-hq"
+    has_opt_einsum = True
+except ImportError:
+    has_opt_einsum = False
 
 try:
     torch._dynamo.config.cache_size_limit = 1_000_000
@@ -26,11 +34,11 @@
 
 
 def precond_update_prob_schedule(
-    n: float,
-    max_prob: float = 1.0,
-    min_prob: float = 0.03,
-    decay: float = 0.001,
-    flat_start: float = 500,
+        n: float,
+        max_prob: float = 1.0,
+        min_prob: float = 0.03,
+        decay: float = 0.001,
+        flat_start: float = 500,
 ) -> torch.Tensor:
     """Anneal preconditioner update probability during beginning of training.
@@ -92,14 +100,8 @@ def __init__(
             flatten_dim: bool = False,
             deterministic: bool = False,
     ):
-        try:
-            # NOTE opt_einsum needed to avoid blowing up memory with einsum ops
-            import opt_einsum
-            opt_einsum.enabled = True
-            opt_einsum.strategy = "auto-hq"
-            import torch.backends.opt_einsum
-        except ImportError:
-            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer." )
+        if not has_opt_einsum:
+            warnings.warn("It is highly recommended to have 'opt_einsum' installed for this optimizer.")
 
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
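For reference, the new module-level block means that importing timm/optim/kron.py now configures PyTorch's einsum backend as a side effect whenever opt_einsum is installed, instead of doing so per-optimizer in __init__. A minimal check sketch (assuming opt_einsum is available in the environment):

    import torch
    import timm.optim.kron  # noqa: F401  # the module-level try/except above runs on import

    # Both backend settings are flipped by the import when opt_einsum is present.
    print(torch.backends.opt_einsum.enabled)   # True
    print(torch.backends.opt_einsum.strategy)  # "auto-hq"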

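A short usage sketch of the other side of the change: when opt_einsum is not installed, the constructor no longer retries the import and instead warns based on the module-level has_opt_einsum flag. The optimizer class name (Kron) and the minimal constructor arguments below are assumptions; only lr appears in the diff:

    import warnings
    import torch
    from timm.optim.kron import Kron  # assumed class name, not shown in the diff

    model = torch.nn.Linear(8, 2)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        opt = Kron(model.parameters(), lr=1e-3)
    # Without opt_einsum installed, `caught` contains the
    # "highly recommended to have 'opt_einsum' installed" warning.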