Initial implementation of PCB merge method #432

Draft · wants to merge 1 commit into `main`
13 changes: 12 additions & 1 deletion README.md
@@ -130,6 +130,8 @@ A quick overview of the currently supported merge methods:
| [Model Stock](https://arxiv.org/abs/2403.19522) | `model_stock` | ✅ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) | `della` | ✅ | ✅ |
| [DELLA](https://arxiv.org/abs/2406.11617) [Task Arithmetic](https://arxiv.org/abs/2212.04089) | `della_linear` | ✅ | ✅ |
| [PCB](https://arxiv.org/abs/2410.02396) | `pcb` | ✅ | ✅ |

### Linear

The classic merge method - a simple weighted average.
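In symbols, given per-model weights $w_i$ (renormalized to sum to one when the `normalize` option is enabled — shown here as a sketch of the behavior, not the exact implementation):

$$\theta_{\text{merged}} = \frac{\sum_i w_i \, \theta_i}{\sum_i w_i}$$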
@@ -195,10 +197,20 @@ Parameters:
Building upon DARE, DELLA uses adaptive pruning based on parameter magnitudes. DELLA first ranks parameters within each row of the delta parameters and assigns drop probabilities inversely proportional to their magnitudes. This allows it to retain the most important changes while reducing interference. After pruning, it rescales the remaining parameters in the same way as [DARE](#dare). DELLA can be used with (`della`) or without (`della_linear`) the sign-election step of TIES; an example configuration follows the parameter list below.

Parameters: same as [Linear](#linear), plus:

- `density` - fraction of weights in differences from the base model to retain
- `epsilon` - maximum change in drop probability based on magnitude. Drop probabilities assigned will range from `density - epsilon` to `density + epsilon`. (When selecting values for `density` and `epsilon`, ensure that the range of probabilities falls within 0 to 1)
- `lambda` - scaling factor for the final merged delta parameters before merging with the base parameters.
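
A minimal sketch of a `della` configuration (model paths and weight values are illustrative assumptions, not part of this PR):

```yml
merge_method: della
base_model: path/to/base
models:
  - model: path/to/model_a
    parameters:
      weight: 0.5
  - model: path/to/model_b
    parameters:
      weight: 0.5
parameters:
  density: 0.7
  epsilon: 0.15
  lambda: 1.0
dtype: float16
```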

### [PCB](https://arxiv.org/abs/2410.02396)

PCB (Parameter Competition Balancing) is a heuristic approach to determining the relative importance of each parameter in each task vector. It combines an intra-task term (how large a parameter is relative to the rest of its own task vector) with an inter-task term (how strongly it interacts with the corresponding parameters of the other task vectors), and uses the combined score to drive both the weighting and the sparsification of each parameter. An example configuration follows the parameter list below.
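
As implemented in this PR's `pcb.py`, the balancing score for task vector $\delta_i$ (one of $N$) combines the two terms elementwise:

$$b_i = b_{\mathrm{intra},i} \odot b_{\mathrm{inter},i}, \qquad b_{\mathrm{intra},i} = \mathrm{Softmax}\!\left(N \cdot \mathrm{Norm}(\delta_i \odot \delta_i)\right), \qquad b_{\mathrm{inter},i} = \sum_{j=1}^{N} \mathrm{Softmax}\!\left(\mathrm{Norm}(\delta_i \odot \delta_j)\right)$$

Only the top `density` fraction of scores in each $b_i$ is kept; the surviving scores are then normalized into per-parameter merge weights.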

Parameters:

- `density` - fraction of weights in differences from the base model to retain
- `weight` - scaling factor applied to the final combined task vector before it is added to the base model parameters.
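
A minimal sketch of a `pcb` configuration (model paths are illustrative assumptions; parameter values mirror this PR's test case):

```yml
merge_method: pcb
base_model: path/to/base
models:
  - model: path/to/model_a
  - model: path/to/model_b
parameters:
  density: 0.8
  weight: 1.0
dtype: float16
```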

## LoRA extraction

Mergekit allows extracting PEFT-compatible low-rank approximations of finetuned models.
@@ -241,7 +253,6 @@ Or download your merge:

`!arcee merging download bio-merge`

## Citation

We now have a [paper](https://arxiv.org/abs/2403.13257) you can cite for the MergeKit library:
3 changes: 3 additions & 0 deletions mergekit/card.py
@@ -118,6 +118,9 @@ def method_md(merge_method: str) -> str:
"dare_ties": "[DARE](https://arxiv.org/abs/2311.03099) [TIES](https://arxiv.org/abs/2306.01708)",
"dare_linear": "linear [DARE](https://arxiv.org/abs/2311.03099)",
"model_stock": "[Model Stock](https://arxiv.org/abs/2403.19522)",
"della": "[DELLA](https://arxiv.org/abs/2406.11617)",
"della_linear": "linear [DELLA](https://arxiv.org/abs/2406.11617)",
"pcb": "[PCB](https://arxiv.org/abs/2410.02396)",
}
return methods.get(merge_method, merge_method)

6 changes: 4 additions & 2 deletions mergekit/merge_methods/__init__.py
@@ -22,6 +22,7 @@
from mergekit.merge_methods.linear import LinearMerge
from mergekit.merge_methods.model_stock import ModelStockMerge
from mergekit.merge_methods.passthrough import PassthroughMerge
from mergekit.merge_methods.pcb import PCBMerge
from mergekit.merge_methods.slerp import SlerpMerge
from mergekit.merge_methods.tokenizer_permute import TokenizerPermutationMerge

@@ -77,22 +78,23 @@ def get(method: str) -> MergeMethod:
        )
    elif method == "model_stock":
        return ModelStockMerge()
    elif method == "della":
        return GeneralizedTaskArithmeticMerge(
            consensus_method=ConsensusMethod.sum,
            sparsification_method=SparsificationMethod.rank_magnitude_sampling,
            default_normalize=True,
            default_rescale=True,
        )
    elif method == "della_linear":
        return GeneralizedTaskArithmeticMerge(
            consensus_method=None,
            sparsification_method=SparsificationMethod.rank_magnitude_sampling,
            default_normalize=False,
            default_rescale=True,
        )
    elif method == "pcb":
        return PCBMerge()

    raise RuntimeError(f"Unimplemented merge method {method}")


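For reference, a minimal sketch of resolving the new method through the registry function shown above (`get` is defined in `mergekit/merge_methods/__init__.py`):

```python
# Sketch: look up the new merge method by name via the registry above.
from mergekit.merge_methods import get

method = get("pcb")  # returns a PCBMerge instance; unknown names raise RuntimeError
```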
118 changes: 118 additions & 0 deletions mergekit/merge_methods/pcb.py
@@ -0,0 +1,118 @@
# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F

from mergekit.architecture import WeightInfo
from mergekit.common import ImmutableMap, ModelReference
from mergekit.graph import Task
from mergekit.merge_methods.base import (
    ConfigParameterDef,
    MergeMethod,
    MergeTensorInput,
)
from mergekit.merge_methods.generalized_task_arithmetic import get_task_vectors


class PCBMerge(MergeMethod):
    def parameters(self) -> List[ConfigParameterDef]:
        return [
            ConfigParameterDef(name="density", required=True),
            ConfigParameterDef(name="weight", required=False, default_value=1.0),
        ]

    def make_task(
        self,
        output_weight: WeightInfo,
        tensors: MergeTensorInput,
        base_model: Optional[ModelReference],
        parameters: ImmutableMap[str, Any],
        **kwargs,
    ) -> Task[torch.Tensor]:
        return PCBMergeTask(
            output_weight=output_weight,
            tensors=tensors,
            base_model=base_model,
            density=parameters["density"],
            weight=parameters["weight"],
        )


class PCBMergeTask(Task[torch.Tensor]):
    output_weight: WeightInfo
    tensors: MergeTensorInput
    base_model: Optional[ModelReference]
    density: float
    weight: float

    def uses_accelerator(self) -> bool:
        return True

    def arguments(self) -> Dict[str, Task]:
        return {"tensors": self.tensors}

    def execute(
        self,
        tensors: Dict[ModelReference, torch.Tensor],
        **_kwargs,
    ) -> torch.Tensor:
        # collect task vectors (per-model deltas from the base model)
        tv_info, base = get_task_vectors(
            self.output_weight,
            self.base_model,
            tensors,
            tensor_parameters=ImmutableMap({model: {} for model in tensors}),
        )
        if not tv_info:
            return base

        n = len(tv_info)
        tvs = torch.stack([tv["delta"] for tv in tv_info], dim=0)
        tvs_flat = tvs.view(n, -1)

        # $b_{intra, i} = Softmax(N \cdot Norm(\delta_i \odot \delta_i))$
        norm_tvs_sqr = F.normalize(tvs_flat * tvs_flat, dim=1)
        b_intra = F.softmax(n * norm_tvs_sqr, dim=1)

        # $b_{inter, i} = \sum_{j = 1}^{n} Softmax(Norm(\delta_i \odot \delta_j))$
        b_inter = torch.zeros_like(tvs_flat)
        for i in range(n):
            inter_prod = tvs_flat[i] * tvs_flat
            inter_norm = F.normalize(inter_prod, dim=1)
            b_inter[i] = F.softmax(inter_norm, dim=1).sum(dim=0)

        # $b_i = b_{intra, i} \odot b_{inter, i}$
        b = b_intra * b_inter

        # $m_i = b_i \geq sorted(b_i)[k]$ -- keep the top-k scores in each row
        k = max(int(tvs_flat.shape[1] * self.density), 1)
        _, indices = torch.topk(b, k, dim=1)
        m = torch.zeros_like(b)
        m.scatter_(1, indices, 1)

        # $\hat{b}_i = b_i \odot m_i$
        b_hat = b * m

        # Normalize per parameter across task vectors; the clamp guards
        # against 0/0 where every task vector dropped a given parameter.
        weights = b_hat / torch.sum(b_hat, dim=0, keepdim=True).clamp_min(1e-12)
        final_delta = torch.sum(tvs_flat * weights, dim=0).view(tvs.shape[1:])
        return base + self.weight * final_delta
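
To make the balancing step easier to verify in isolation, here is a self-contained sketch of the same computation on random task vectors (shapes, names, and values are illustrative, not part of the PR):

```python
import torch
import torch.nn.functional as F

n, d = 3, 8            # three task vectors, eight parameters each
density = 0.5
tvs_flat = torch.randn(n, d)

# intra-task score: softmax over normalized squared magnitudes
b_intra = F.softmax(n * F.normalize(tvs_flat * tvs_flat, dim=1), dim=1)

# inter-task score: accumulate softmaxed, normalized cross-products
b_inter = torch.zeros_like(tvs_flat)
for i in range(n):
    b_inter[i] = F.softmax(F.normalize(tvs_flat[i] * tvs_flat, dim=1), dim=1).sum(dim=0)

# combine, keep top-k per task vector, renormalize per parameter
b = b_intra * b_inter
k = max(int(d * density), 1)
mask = torch.zeros_like(b).scatter_(1, torch.topk(b, k, dim=1).indices, 1)
b_hat = b * mask
weights = b_hat / b_hat.sum(dim=0, keepdim=True).clamp_min(1e-12)

final_delta = (tvs_flat * weights).sum(dim=0)
print(final_delta.shape)  # torch.Size([8])
```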
20 changes: 20 additions & 0 deletions tests/test_basic_merges.py
@@ -137,6 +137,26 @@ def test_model_stock_merge(self, model_a, model_b, model_c):
)
run_and_check_merge(config)

    def test_della_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="della",
            base_model=model_c,
            params={"density": 0.7, "epsilon": 0.15, "lambda": 1.0},
        )
        run_and_check_merge(config)

    def test_pcb_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_a,
            model_b,
            merge_method="pcb",
            base_model=model_c,
            params={"density": 0.8, "weight": 1.0},
        )
        run_and_check_merge(config)

    def test_model_stock_filterwise_merge(self, model_a, model_b, model_c):
        config = self.two_model_config(
            model_b,