rotate fast #43

Open · wants to merge 7 commits into base: main
File renamed without changes.
File renamed without changes.
63 changes: 39 additions & 24 deletions sd_mecha/merge_methods/__init__.py
@@ -1,13 +1,12 @@
import functools
import math
import operator
import numpy as np
import torch
from scipy.stats import binom
from torch import Tensor
from typing import Tuple, TypeVar, Dict, Optional
from sd_mecha.hypers import Hyper
from .svd import orthogonal_procrustes, fractional_matrix_power
from .svd import orthogonal_procrustes, close_ortho_columns_full, fractional_orthogonal_matrix_power, MatmulIdentity
from sd_mecha.merge_space import MergeSpace
from sd_mecha.extensions.merge_method import LiftFlag, convert_to_recipe

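For context, the alignment step of `rotate` solves an orthogonal Procrustes problem, and this PR changes the `.svd` import so the two SVD factors are kept separate (plus new helpers for the fractional-alignment path). The `.svd` module itself is not part of this diff, so the following is only a minimal sketch of the standard SVD-based Procrustes solution that `orthogonal_procrustes` presumably wraps; the `cancel_reflection` handling shown is the common Kabsch-style sign flip and is an assumption, not the PR's exact code.

```python
import torch

def orthogonal_procrustes_sketch(a: torch.Tensor, b: torch.Tensor, cancel_reflection: bool = False):
    """Sketch only: find the orthogonal R minimizing ||a @ R - b||_F.
    Returns the SVD factors u, vh so the caller can form R = u @ vh,
    or apply the factors separately as the new rotate code does."""
    u, _, vh = torch.linalg.svd(a.mT @ b)
    if cancel_reflection and torch.det(u @ vh) < 0:
        # assumption: flip the last left-singular vector so det(u @ vh) = +1,
        # turning a reflection into a proper rotation
        u = u.clone()
        u[:, -1] = -u[:, -1]
    return u, vh
```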
@@ -524,18 +523,30 @@ def rotate(
cache: Optional[Dict[str, Dict[str, Tensor]]] = None,
**kwargs,
) -> Tensor | SameMergeSpace:
key = kwargs.get("key", "")
if key.endswith(("in_proj_weight", "in_proj_bias")):
# workaround for concatenated attention projection layers
vs = []
for i, k in enumerate(("to_q", "to_k", "to_v")):
k_kwargs = kwargs.copy()
k_kwargs["key"] = key.replace("in_proj_", f"{k}.")
dim = a.shape[0] // 3
t_start = dim*i
t_end = dim*(i+1)
k_models = tuple(m[t_start:t_end] for m in (a, b))
vs.append(rotate.__wrapped__(*k_models, alignment=alignment, alpha=alpha, **k_kwargs))
return torch.cat(vs)

if alignment == 0 and alpha == 0:
return a

if len(a.shape) < 2 or torch.allclose(a.half(), b.half()):
return weighted_sum.__wrapped__(a, b, alpha=alpha)
if len(a.shape) == 1 and key.endswith(("weight", "scale")):
return geometric_sum.__wrapped__(a, b, alpha=alpha)

is_conv = len(a.shape) == 4 and a.shape[-1] != 1
if is_conv:
shape_2d = (-1, functools.reduce(operator.mul, a.shape[2:]))
else:
shape_2d = (a.shape[0], a.shape[1:].numel())
if len(a.shape) < 2 or key.endswith("bias") or "position" in key or torch.allclose(a.half(), b.half()):
return weighted_sum.__wrapped__(a, b, alpha=alpha)

shape_2d = (a.shape[0], a.shape[1:].numel())
a_neurons = a.reshape(*shape_2d)
b_neurons = b.reshape(*shape_2d)
a_centroid = a_neurons.mean(0)
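The first addition in this hunk handles attention layers whose q, k and v projections are stored as a single concatenated `in_proj_weight` / `in_proj_bias` tensor: each third along dim 0 is merged independently under a per-projection cache key (`to_q`, `to_k`, `to_v`) and the results are concatenated back. Below is a self-contained sketch of that slicing, with a placeholder `merge_fn` standing in for the recursive `rotate.__wrapped__` call and its full hyperparameter set (the key name in the usage line is illustrative only).

```python
import torch

def merge_in_proj_sketch(a: torch.Tensor, b: torch.Tensor, key: str, merge_fn):
    """Sketch: split concatenated q/k/v projections into thirds along dim 0,
    merge each slice under its own key, then concatenate the merged slices."""
    dim = a.shape[0] // 3  # q, k and v each own one third of the first axis
    merged = []
    for i, name in enumerate(("to_q", "to_k", "to_v")):
        slice_key = key.replace("in_proj_", f"{name}.")  # e.g. "...to_q.weight"
        merged.append(merge_fn(a[dim * i:dim * (i + 1)],
                               b[dim * i:dim * (i + 1)],
                               key=slice_key))
    return torch.cat(merged)

# usage sketch with a trivial stand-in merge (plain average):
w = merge_in_proj_sketch(
    torch.randn(3 * 320, 320), torch.randn(3 * 320, 320),
    key="text_model.in_proj_weight",  # hypothetical key, for illustration
    merge_fn=lambda x, y, key: (x + y) / 2,
)
```

The other changes in this hunk route 1-D weight/scale tensors through `geometric_sum` and biases/position embeddings through plain `weighted_sum`, presumably because those tensors have no 2-D neuron structure to align.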
@@ -551,29 +562,33 @@ def rotate(
cache[key] = {}
cache = cache[key]

if cache is not None and "rotation" in cache:
rotation = transform = cache["rotation"].to(a.device, a.dtype)
if cache is not None and "u" in cache:
u = cache["u"].to(a.device, a.dtype)
vh = cache["vh"].to(a.device, a.dtype)
else:
rotation = transform = orthogonal_procrustes(a_neurons, b_neurons, cancel_reflection=alignment_is_float)
u, vh = orthogonal_procrustes(a_neurons, b_neurons, cancel_reflection=alignment_is_float)
if cache is not None:
cache["rotation"] = rotation.to("cpu", torch.float16)
cache["u"] = u.to("cpu", torch.bfloat16)
cache["vh"] = vh.to("cpu", torch.bfloat16)

def inverse_align(x): return (x @ vh.mH) @ u.mH

if alignment_is_float:
transform = fractional_matrix_power(transform, alignment, cache)
u, v, proj = close_ortho_columns_full(u, vh.mH)
del vh
def inverse_align(x): return (((x @ proj) @ v) @ u.mH) @ proj.mH
def transform(x): return ((x @ proj) @ fractional_orthogonal_matrix_power(u @ v.mH, alignment, cache)) @ proj.mH
elif alignment == 0:
transform = torch.eye(
len(transform),
dtype=transform.dtype,
device=transform.device,
)
def transform(x): return x
elif alignment != 1:
transform = torch.linalg.matrix_power(transform, round(alignment))
def transform(x): return x @ torch.linalg.matrix_power(u @ vh, round(alignment))
else:
def transform(x): return (x @ u) @ vh

if alpha != 0:
# interpolate the relationship between the neurons
a_neurons = weighted_sum.__wrapped__(a_neurons, b_neurons @ rotation.T, alpha=alpha)
if not math.isclose(alpha, 0):
a_neurons = (1-alpha)*a_neurons + alpha*inverse_align(b_neurons)

a_neurons @= transform
a_neurons = transform(a_neurons)
a_neurons += weighted_sum.__wrapped__(a_centroid, b_centroid, alpha=alignment)
return a_neurons.reshape_as(a)

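For a float `alignment`, this hunk replaces `fractional_matrix_power` on the assembled rotation with `close_ortho_columns_full` plus `fractional_orthogonal_matrix_power`, and it re-expresses the alpha step as `(1 - alpha) * a + alpha * inverse_align(b)`, i.e. b is mapped back into a's basis before the (possibly fractional) alignment is applied. Neither new helper appears in this diff, so what follows is only a conceptual sketch of a fractional power of an orthogonal matrix via its complex eigendecomposition: the unit-modulus eigenvalues have their phases scaled by t, which interpolates from the identity at t = 0 to the full rotation at t = 1. The PR's actual helpers may use a different, faster factorization, and the `proj` handling of non-square factors is not reconstructed here.

```python
import torch

def fractional_orthogonal_power_sketch(q: torch.Tensor, t: float) -> torch.Tensor:
    """Sketch: Q**t for an orthogonal Q, so the alignment can be applied partially."""
    eigvals, eigvecs = torch.linalg.eig(q)                     # complex eigendecomposition of real Q
    theta = torch.angle(eigvals)                               # eigenvalues are exp(i * theta)
    powered = torch.polar(torch.ones_like(theta), theta * t)   # exp(i * theta * t)
    q_t = (eigvecs * powered) @ torch.linalg.inv(eigvecs)      # V @ diag(powered) @ V^-1
    return q_t.real.to(q.dtype)

# usage sketch: half of a 90-degree rotation is a 45-degree rotation
q90 = torch.tensor([[0.0, -1.0], [1.0, 0.0]])
print(fractional_orthogonal_power_sketch(q90, 0.5))
# ~ [[0.7071, -0.7071], [0.7071, 0.7071]]
```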