Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NearSwap #480

Merged
merged 13 commits into from
Jan 25, 2025
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ A quick overview of the currently supported merge methods:
| ------------------------------------------------------------------------------------------------ | -------------------- | ----------- | --------------- |
| Linear ([Model Soups](https://arxiv.org/abs/2203.05482)) | `linear` | ✅ | ❌ |
| SLERP | `slerp` | ❌ | ✅ |
| Nearswap | `nearswap` | ❌ | ✅ |
| [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `task_arithmetic` | ✅ | ✅ |
| [TIES](https://arxiv.org/abs/2306.01708) | `ties` | ✅ | ✅ |
| [DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708) | `dare_ties` | ✅ | ✅ |
Expand Down Expand Up @@ -272,6 +273,14 @@ Parameters:

- `t` - interpolation factor. At `t=0` will return `base_model`, at `t=1` will return the other one.

### Nearswap

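Interpolates the base model with a secondary model, moving each parameter toward the secondary model's value by at most `t`: where the two models differ by `t` or less, the secondary model's value is used outright; where they differ by more, the secondary model is blended in with weight `t / |difference|`. Accepts two models, one of which must be the base model.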

Parameters:

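- `t` - similarity threshold (must be greater than zero): the largest absolute difference between corresponding weights at which the secondary model's value fully replaces the base model's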

### [Task Arithmetic](https://arxiv.org/abs/2212.04089)

Computes "task vectors" for each model by subtracting a base model. Merges the task vectors linearly and adds back the base. Works great for models that were fine tuned from a common ancestor. Also a super useful mental framework for several of the more involved merge methods.
Expand Down
126 changes: 126 additions & 0 deletions mergekit/merge_methods/nearswap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
# Copyright (C) 2025 Charles Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional, Union

import torch

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import (
ConfigParameterDef,
MergeMethod,
MergeTensorInput,
)
from mergekit.merge_methods.rectify_embed import rectify_embed_sizes


class NearSwapTask(Task[torch.Tensor]):
    """Graph task that NearSwap-merges one weight from two models.

    The tensor belonging to ``base_model`` is treated as the starting point
    and the other tensor as the secondary model; see ``nearswap`` for the
    element-wise rule.
    """

    # Upstream task producing the per-model tensors handed to ``execute``.
    gather_tensors: MergeTensorInput
    # Model whose tensor is used as the base (first) operand.
    base_model: ModelReference
    # Threshold forwarded to ``nearswap``; must be strictly positive.
    t: float
    # Description of the weight being merged, passed to rectify_embed_sizes.
    weight_info: WeightInfo

    def uses_accelerator(self) -> bool:
        """Always returns True."""
        return True

    def arguments(self) -> Dict[str, Task]:
        """Return the single upstream dependency under the key ``"tensors"``."""
        # NOTE(review): the key presumably has to match the parameter name of
        # ``execute`` -- confirm against the Task base class.
        return {"tensors": self.gather_tensors}

    def execute(self, tensors: Dict[ModelReference, torch.Tensor]) -> torch.Tensor:
        """Merge the gathered tensors for this weight.

        Parameters:
            tensors: Mapping from model reference to that model's tensor.

        Returns:
            torch.Tensor: The only input tensor when a single one was given,
            otherwise the NearSwap result converted to the dtype and device
            of the (size-rectified) base tensor.

        Raises:
            RuntimeError: If ``t`` is not positive, if the number of input
                tensors is neither one nor two, or if ``base_model`` is not
                among the inputs.
        """
        if self.t <= 0:
            raise RuntimeError(f"Threshold cannot be <= zero, got {self.t}")

        if len(tensors) == 1:
            # Nothing to merge against; hand back the lone tensor untouched.
            return list(tensors.values())[0]
        if len(tensors) != 2:
            raise RuntimeError(
                f"Nearswap merge expects exactly two models, got {len(tensors)}"
            )
        if self.base_model not in tensors:
            raise RuntimeError("Base model not in input tensors")

        # Order the pair so the base model's tensor comes first.
        (first_ref, first_tensor), (_, second_tensor) = tensors.items()
        if first_ref != self.base_model:
            first_tensor, second_tensor = second_tensor, first_tensor
        pair = [first_tensor, second_tensor]

        # Works on the list in place, so read the tensors back from ``pair``.
        rectify_embed_sizes(self.weight_info, pair)

        merged = nearswap(self.t, pair[0], pair[1])
        merged = merged.to(pair[0].dtype)
        return merged.to(pair[0].device)


class NearSwapMerge(MergeMethod):
    """Merge-method entry for NearSwap; builds a ``NearSwapTask`` per weight."""

    def name(self) -> str:
        """Return the method's identifier string, ``"nearswap"``."""
        return "nearswap"

    def pretty_name(self) -> Optional[str]:
        """Return the human-readable display name."""
        return "NearSwap"

    def reference_url(self) -> Optional[str]:
        """Return the URL of the model card the method was adapted from."""
        return "https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001"

    def parameters(self) -> List[ConfigParameterDef]:
        """Declare the method's only parameter: the required threshold ``t``."""
        threshold = ConfigParameterDef(name="t", required=True)
        return [threshold]

    def make_task(
        self,
        *,
        output_weight: WeightInfo,
        tensors: MergeTensorInput,
        parameters: ImmutableMap[str, Any],
        base_model: Optional[ModelReference],
        **_kwargs,
    ) -> Task:
        """Create the task that merges a single output weight.

        Parameters:
            output_weight: Description of the weight being produced.
            tensors: Upstream task supplying the per-model input tensors.
            parameters: Resolved method parameters; ``"t"`` is read from it.
            base_model: Model to treat as the base of the merge.
            **_kwargs: Additional keyword arguments, ignored.

        Returns:
            Task: A ``NearSwapTask`` configured for this weight.
        """
        threshold = parameters["t"]
        return NearSwapTask(
            t=threshold,
            weight_info=output_weight,
            base_model=base_model,
            gather_tensors=tensors,
        )


def nearswap(t: float, v0: torch.Tensor, v1: torch.Tensor) -> torch.Tensor:
    """
    Blend v0 toward v1 element-wise, with a weight that shrinks as they diverge.

    Each element gets the blend weight ``t / |v0 - v1|`` limited to the range
    [0, 1]. Elements that differ by ``t`` or less therefore take the value of
    ``v1``; elements that differ by more keep mostly the value of ``v0``.

    Adapted from: https://huggingface.co/alchemonaut/QuartetAnemoi-70B-t0.0001

    Parameters:
        t (float): The sameness threshold.
        v0 (torch.Tensor): Weights from the base model.
        v1 (torch.Tensor): Weights from the secondary model.

    Returns:
        torch.Tensor: Resulting interpolated weights.
    """
    # Element-wise distance between the two sets of weights
    distance = (v0 - v1).abs()

    # Where the weights are identical the division yields inf (or NaN);
    # treat those elements as fully swappable, i.e. blend weight 1.
    ratio = torch.nan_to_num(t / distance, nan=1.0, posinf=1.0, neginf=1.0)
    blend = ratio.clamp(min=0.0, max=1.0)

    # Weighted sum: blend == 1 yields v1, blend == 0 yields v0
    return blend * v1 + (1 - blend) * v0
2 changes: 2 additions & 0 deletions mergekit/merge_methods/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)
from mergekit.merge_methods.linear import LinearMerge
from mergekit.merge_methods.model_stock import ModelStockMerge
from mergekit.merge_methods.nearswap import NearSwapMerge
from mergekit.merge_methods.nuslerp import NuSlerpMerge
from mergekit.merge_methods.passthrough import PassthroughMerge
from mergekit.merge_methods.sce import SCEMerge
Expand All @@ -35,6 +36,7 @@
PassthroughMerge(),
ModelStockMerge(),
SCEMerge(),
NearSwapMerge(),
# generalized task arithmetic methods
GeneralizedTaskArithmeticMerge(
consensus_method=None,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_basic_merges.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,13 @@ def test_slerp_merge(self, model_a, model_b):
config.parameters = {"t": 0.35}
run_and_check_merge(config)

def test_nearswap_merge(self, model_a, model_b):
    """Run a nearswap merge of two models with model_a as the base."""
    nearswap_config = self.two_model_config(
        model_a,
        model_b,
        merge_method="nearswap",
        base_model=model_a,
    )
    nearswap_config.parameters = {"t": 0.0001}
    run_and_check_merge(nearswap_config)

def test_nuslerp_merges(self, model_a, model_b, model_c):
for base_model in [None, model_c]:
for row_wise in [False, True]:
Expand Down
Loading