From 8147b1a34a32110f4b15178915ab3fe86690a1c8 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Mon, 29 Jan 2024 18:27:01 -0800 Subject: [PATCH 1/7] WIP multi-module architecture support --- mergekit/architecture/__init__.py | 54 ++++ mergekit/architecture/base.py | 136 ++++++++++ .../decoder_only.py} | 236 ++++------------- mergekit/config.py | 63 ++++- mergekit/merge.py | 32 ++- mergekit/plan.py | 248 ++++++++++++------ tests/common.py | 2 +- 7 files changed, 479 insertions(+), 292 deletions(-) create mode 100644 mergekit/architecture/__init__.py create mode 100644 mergekit/architecture/base.py rename mergekit/{architecture.py => architecture/decoder_only.py} (56%) diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py new file mode 100644 index 00000000..1061aa84 --- /dev/null +++ b/mergekit/architecture/__init__.py @@ -0,0 +1,54 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + + +from transformers import PretrainedConfig + +from mergekit.architecture.base import ( + ModelArchitecture, + ModuleArchitecture, + ModuleDefinition, + StaticLayeredModuleArchitecture, + WeightInfo, +) +from mergekit.architecture.decoder_only import get_decoder_only_arch + + +def get_architecture_info(config: PretrainedConfig) -> ModelArchitecture: + if len(config.architectures) != 1: + raise RuntimeError("More than one architecture in config?") + arch_name = config.architectures[0] + + if decoder := get_decoder_only_arch(arch_name, config=config): + if isinstance(decoder, StaticLayeredModuleArchitecture): + num_layers = getattr(config, decoder.num_layers_config_key()) + decoder = StaticLayeredModuleArchitecture( + **decoder.model_dump(exclude=["configured_num_layers"]), + configured_num_layers=num_layers, + ) + return ModelArchitecture( + modules={"decoder": ModuleDefinition(architecture=decoder)} + ) + + raise RuntimeError(f"Unsupported architecture {arch_name}") + + +__all__ = [ + "ModelArchitecture", + "ModuleArchitecture", + "ModuleDefinition", + "WeightInfo", + "get_architecture_info", +] diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py new file mode 100644 index 00000000..9835e96a --- /dev/null +++ b/mergekit/architecture/base.py @@ -0,0 +1,136 @@ +# Copyright (C) 2024 Charles O. Goddard +# +# This software is free software: you can redistribute it and/or +# modify it under the terms of the GNU Lesser General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# This software is distributed in the hope that it will be useful, but +# WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. 
+# +# You should have received a copy of the GNU Lesser General Public License +# along with this program. If not, see http://www.gnu.org/licenses/. + +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from pydantic import BaseModel + + +class WeightInfo(BaseModel): + name: str + is_embed: bool = False + + def prefixed_name(self, prefix: Optional[str] = None): + if prefix: + return prefix + self.name + return self.name + + +class ModuleArchitecture(ABC): + @abstractmethod + def num_layers(self) -> int: + """Return the number of layers in this module.""" + ... + + @abstractmethod + def layer_weights(self, index: int) -> Optional[List[WeightInfo]]: + """Return a list of all weights associated with a given layer.""" + ... + + @abstractmethod + def pre_weights(self) -> List[WeightInfo]: + """Return a list of all weights preceding the first layer.""" + ... + + @abstractmethod + def post_weights(self) -> List[WeightInfo]: + """Return a list of all weights following the final layer.""" + ... + + @abstractmethod + def slicable(self) -> bool: + """Return True if the architecture can be sliced meaningfully.""" + ... + + def num_layers_config_key(self) -> str: + """Key in config that represents number of layers""" + return "num_hidden_layers" + + def all_weights(self) -> List[str]: + num_layers = self.num_layers() + tensor_names = list(self.pre_weights()) + for layer_idx in range(num_layers): + tensor_names.extend(self.layer_weights(index=layer_idx)) + tensor_names.extend(self.post_weights()) + return [ti.name if isinstance(ti, WeightInfo) else ti for ti in tensor_names] + + +class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True): + architecture: ModuleArchitecture + weight_prefix: Optional[str] = None + config_prefix: Optional[str] = None + subfolder: Optional[str] = None + + +class ModelArchitecture(BaseModel, frozen=True, arbitrary_types_allowed=True): + modules: Dict[str, ModuleDefinition] + + def all_weights(self) -> List[str]: + tensor_names = [] + for module in self.modules.values(): + prefix = module.weight_prefix or "" + for name in module.architecture.all_weights(): + tensor_names.append(prefix + name) + return tensor_names + + +class StaticLayeredModuleArchitecture(ModuleArchitecture, BaseModel, frozen=True): + name: str + + pre_weight_names: List[str] + post_weight_names: List[str] + embed_weight_names: List[str] + layer_prefix_format: str + layer_weight_suffixes: List[str] + num_layers_key: Optional[str] = None + is_slicable: bool = True + configured_num_layers: Optional[int] = None + + def num_layers(self) -> int: + if not self.configured_num_layers: + raise RuntimeError( + "num_layers() called on module with no configured_num_layers set" + ) + return self.configured_num_layers + + def layer_weights(self, index: int) -> Optional[List[WeightInfo]]: + if index >= self.configured_num_layers: + return None + res = [] + for suffix in self.layer_weight_suffixes: + name = self.layer_prefix_format.format(idx=index) + "." 
+ suffix + res.append(WeightInfo(name=name, is_embed=name in self.embed_weight_names)) + return res + + def pre_weights(self) -> List[WeightInfo]: + return [ + WeightInfo(name=name, is_embed=name in self.embed_weight_names) + for name in self.pre_weight_names + ] + + def post_weights(self) -> List[WeightInfo]: + return [ + WeightInfo(name=name, is_embed=name in self.embed_weight_names) + for name in self.post_weight_names + ] + + def num_layers_config_key(self) -> str: + if self.num_layers_key: + return self.num_layers_key + return super().num_layers_config_key() + + def slicable(self) -> bool: + return self.is_slicable diff --git a/mergekit/architecture.py b/mergekit/architecture/decoder_only.py similarity index 56% rename from mergekit/architecture.py rename to mergekit/architecture/decoder_only.py index 658f99ac..ce60f8d3 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture/decoder_only.py @@ -13,85 +13,17 @@ # You should have received a copy of the GNU Lesser General Public License # along with this program. If not, see http://www.gnu.org/licenses/. -from abc import ABC, abstractmethod -from typing import ClassVar, List, Optional +from typing import List, Optional -from pydantic import BaseModel from transformers import PretrainedConfig +from mergekit.architecture.base import ( + ModuleArchitecture, + StaticLayeredModuleArchitecture, + WeightInfo, +) -class ArchitectureInfo(ABC): - @abstractmethod - def pre_weights(self) -> List[str]: - """Return a list of all weights preceding the first layer.""" - ... - - @abstractmethod - def post_weights(self) -> List[str]: - """Return a list of all weights following the final layer.""" - ... - - @abstractmethod - def layer_weight_formats(self) -> List[str]: - """Return a list of format strings all weights associated with a layer.""" - ... - - @abstractmethod - def embed_weights(self) -> List[str]: - ... - - def num_layers(self, config: PretrainedConfig) -> int: - return config.num_hidden_layers - - def num_layers_config_key(self) -> str: - """Key in config that represents number of layers""" - return "num_hidden_layers" - - -class StaticTensorNames(ArchitectureInfo, BaseModel, frozen=True): - name: str - - pre_weight_names: List[str] # weights applied before first layer - post_weight_names: List[str] # weights applied after last layer - embed_weight_names: List[str] # weights for embed/lm_head - layer_prefix_format: str - layer_weight_suffixes: List[str] - num_layers_key: Optional[str] = None - - def pre_weights(self) -> List[str]: - return self.pre_weight_names - - def post_weights(self) -> List[str]: - return self.post_weight_names - - def embed_weights(self) -> List[str]: - return self.embed_weight_names - - def layer_weight_formats(self) -> List[str]: - res = [] - for suffix in self.layer_weight_suffixes: - res.append(self.layer_prefix_format + "." 
+ suffix) - return res - - def num_layers_config_key(self) -> str: - if self.num_layers_key: - return self.num_layers_key - return super().num_layers_config_key() - - def num_layers(self, config: PretrainedConfig) -> int: - return getattr(config, self.num_layers_config_key()) - - def all_weights(self, config: PretrainedConfig) -> List[str]: - num_layers = self.num_layers(config) - tensor_names = list(self.pre_weights()) - for layer_idx in range(num_layers): - for f in self.layer_weight_formats(): - tensor_names.append(f.format(idx=layer_idx)) - tensor_names.extend(self.post_weights()) - return tensor_names - - -LLAMA_INFO = StaticTensorNames( +LLAMA_INFO = StaticLayeredModuleArchitecture( name="LlamaForCausalLM", pre_weight_names=["model.embed_tokens.weight"], post_weight_names=["model.norm.weight", "lm_head.weight"], @@ -110,48 +42,14 @@ def all_weights(self, config: PretrainedConfig) -> List[str]: ], ) -MISTRAL_INFO = StaticTensorNames( +MISTRAL_INFO = StaticLayeredModuleArchitecture( name="MistralForCausalLM", # lol **LLAMA_INFO.model_dump(exclude=["name"]), ) -class MixtralTensorNames(ArchitectureInfo, BaseModel): - ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM" - num_local_experts: int - - @classmethod - def from_config(cls, config: PretrainedConfig): - return MixtralTensorNames(num_local_experts=config.num_local_experts) - - def pre_weights(self) -> List[str]: - return MISTRAL_INFO.pre_weights() - - def post_weights(self) -> List[str]: - return MISTRAL_INFO.post_weights() - - def embed_weights(self) -> List[str]: - return MISTRAL_INFO.embed_weights() - - def num_layers_config_key(self) -> str: - return MISTRAL_INFO.num_layers_config_key() - - def layer_weight_formats(self) -> List[str]: - num_experts = self.num_local_experts - res = [fmt for fmt in MISTRAL_INFO.layer_weight_formats() if ".mlp." 
not in fmt]
-        for expert_idx in range(num_experts):
-            for param in ("w1", "w2", "w3"):
-                fmt = (
-                    MISTRAL_INFO.layer_prefix_format
-                    + f".block_sparse_moe.experts.{expert_idx}.{param}.weight"
-                )
-                res.append(fmt)
-        res.append(MISTRAL_INFO.layer_prefix_format + ".block_sparse_moe.gate.weight")
-        return res
-
-
-STABLELM_INFO = StaticTensorNames(
+STABLELM_INFO = StaticLayeredModuleArchitecture(
     name="StableLMEpochForCausalLM",
     post_weight_names=LLAMA_INFO.post_weight_names + ["model.norm.bias"],
     layer_weight_suffixes=LLAMA_INFO.layer_weight_suffixes
@@ -164,7 +62,7 @@ def layer_weight_formats(self) -> List[str]:
     ),
 )
 
-GPT_NEOX_INFO = StaticTensorNames(
+GPT_NEOX_INFO = StaticLayeredModuleArchitecture(
     name="GPTNeoXForCausalLM",
     pre_weight_names=["gpt_neox.embed_in.weight"],
     post_weight_names=[
@@ -191,7 +89,7 @@ def layer_weight_formats(self) -> List[str]:
     + ["attention.bias", "attention.masked_bias", "attention.rotary_emb.inv_freq"],
 )
 
-GPT2_INFO = StaticTensorNames(
+GPT2_INFO = StaticLayeredModuleArchitecture(
     name="GPT2LMHeadModel",
     pre_weight_names=["wte.weight", "wpe.weight"],
     post_weight_names=["ln_f.weight", "ln_f.bias"],
@@ -216,7 +114,7 @@ def layer_weight_formats(self) -> List[str]:
     num_layers_key="n_layer",
 )
 
-JAIS_INFO = StaticTensorNames(
+JAIS_INFO = StaticLayeredModuleArchitecture(
     name="JAISLMHeadModel",
     pre_weight_names=["transformer.wte.weight", "transformer.relative_pe.slopes"],
     post_weight_names=["transformer.ln_f.weight", "transformer.ln_f.bias"],
@@ -241,7 +139,7 @@ def layer_weight_formats(self) -> List[str]:
     num_layers_key="n_layer",
 )
 
-GPT2_SEQCLASS_INFO = StaticTensorNames(
+GPT2_SEQCLASS_INFO = StaticLayeredModuleArchitecture(
     name="GPT2ForSequenceClassification",
     pre_weight_names=["transformer.wte.weight", "transformer.wpe.weight"],
     post_weight_names=[
@@ -256,7 +154,7 @@ def layer_weight_formats(self) -> List[str]:
 )
 
 
-QWEN_INFO = StaticTensorNames(
+QWEN_INFO = StaticLayeredModuleArchitecture(
     name="QWenLMHeadModel",
     pre_weight_names=["transformer.wte.weight"],
     post_weight_names=["transformer.ln_f.weight", "lm_head.weight"],
@@ -274,7 +172,7 @@ def layer_weight_formats(self) -> List[str]:
     ],
 )
 
-CHATGLM_INFO = StaticTensorNames(
+CHATGLM_INFO = StaticLayeredModuleArchitecture(
     name="ChatGLMModel",
     pre_weight_names=[
         "transformer.embedding.word_embeddings.weight",
@@ -300,57 +198,37 @@ def layer_weight_formats(self) -> List[str]:
     ],
 )
 
-FALCON_INFO = StaticTensorNames(
-    name="FalconForCausalLM",
-    pre_weight_names=["transformer.word_embeddings.weight"],
-    post_weight_names=[
-        "transformer.ln_f.weight",
-        "transformer.ln_f.bias",
-        "lm_head.weight",
-    ],
-    embed_weight_names=["transformer.word_embeddings.weight", "lm_head.weight"],
-    layer_prefix_format="transformer.h.{idx}",
-    layer_weight_suffixes=[
-        "ln_attn.bias",
-        "ln_attn.weight",
-        "ln_mlp.bias",
-        "ln_mlp.weight",
-        "mlp.dense_4h_to_h.weight",
-        "mlp.dense_h_to_4h.weight",
-        "self_attention.dense.weight",
-        "self_attention.query_key_value.weight",
-    ],
-)
+class PhiDecoderArchitecture(ModuleArchitecture):
+    architecture_name: str = "MixFormerSequentialForCausalLM"
+    num_configured_layers: int
 
-
-class PhiTensorNames(ArchitectureInfo, BaseModel):
-    ARCHITECTURE_NAME: ClassVar[str] = "MixFormerSequentialForCausalLM"
-    n_layer: int
+    def __init__(self, config: PretrainedConfig):
+        self.num_configured_layers = getattr(config, self.num_layers_config_key())
 
-    def from_config(cls, config: PretrainedConfig):
-        return PhiTensorNames(n_layer=config.n_layer)
+    def __eq__(self, rhs: ModuleArchitecture):
+        if not isinstance(rhs, PhiDecoderArchitecture):
+            return False
+        return self.num_layers() == rhs.num_layers()
 
-    def pre_weights(self) -> List[str]:
-        return ["layers.0.wte.weight"]
+    def pre_weights(self) -> List[WeightInfo]:
+        return [WeightInfo(name="layers.0.wte.weight", is_embed=True)]
 
-    def post_weights(self) -> List[str]:
-        fake_layer_idx = self.n_layer
+    def post_weights(self) -> List[WeightInfo]:
+        fake_layer_idx = self.num_configured_layers + 1
         return [
-            f"layers.{fake_layer_idx}.{suffix}"
+            WeightInfo(
+                name=f"layers.{fake_layer_idx}.{suffix}", is_embed="linear" in suffix
+            )
             for suffix in ["linear.bias", "linear.weight", "ln.bias", "ln.weight"]
         ]
 
-    def embed_weights(self) -> List[str]:
-        fake_layer_idx = self.n_layer
-        return [
-            "layers.0.wte.weight",
-            f"layers.{fake_layer_idx}.linear.weight",
-            f"layers.{fake_layer_idx}.linear.bias",
-        ]
+    def num_layers(self) -> int:
+        return self.num_configured_layers
 
-    def layer_weight_formats(self) -> List[str]:
+    def layer_weights(self, index: int) -> Optional[List[WeightInfo]]:
         return [
-            ("layers.{idx}." + suffix)
+            WeightInfo(name=("layers.{idx}." + suffix).format(idx=index))
             for suffix in [
                 "ln.bias",
                 "ln.weight",
@@ -366,14 +244,14 @@ def layer_weight_formats(self) -> List[str]:
             ]
         ]
 
-    def num_layers(self, config: PretrainedConfig) -> int:
-        return config.n_layer
+    def slicable(self) -> bool:
+        return True
 
     def num_layers_config_key(self) -> str:
         return "n_layer"
 
 
-PHI2_INFO = StaticTensorNames(
+PHI2_INFO = StaticLayeredModuleArchitecture(
     name="PhiForCausalLM",
     pre_weight_names=["transformer.embd.wte.weight"],
     post_weight_names=[
@@ -400,7 +278,7 @@ def num_layers_config_key(self) -> str:
 )
 
 
-PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticTensorNames(
+PHI2_INFO_AGAIN_BUT_DIFFERENT = StaticLayeredModuleArchitecture(
     name="PhiForCausalLM",
     pre_weight_names=["model.embed_tokens.weight"],
     post_weight_names=[
@@ -430,33 +308,11 @@ def num_layers_config_key(self) -> str:
 )
 
 
-BAICHUAN_INFO = StaticTensorNames(
-    name="BaichuanForCausalLM",
-    pre_weight_names=["model.embed_tokens.weight"],
-    post_weight_names=["model.norm.weight", "lm_head.weight"],
-    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
-    layer_prefix_format="model.layers.{idx}",
-    layer_weight_suffixes=[
-        "input_layernorm.weight",
-        "self_attn.W_pack.weight",
-        "self_attn.o_proj.weight",
-        "post_attention_layernorm.weight",
-        "mlp.gate_proj.weight",
-        "mlp.down_proj.weight",
-        "mlp.up_proj.weight",
-    ],
-)
-
-
-def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
-    if len(config.architectures) != 1:
-        raise RuntimeError("More than one architecture in config?")
-
-    arch_name = config.architectures[0]
-    if arch_name == PhiTensorNames.ARCHITECTURE_NAME:
-        return PhiTensorNames.from_config(config)
-    if arch_name == MixtralTensorNames.ARCHITECTURE_NAME:
-        return MixtralTensorNames.from_config(config)
+def get_decoder_only_arch(
+    arch_name: str, config: PretrainedConfig
+) -> Optional[ModuleArchitecture]:
+    if arch_name == PhiDecoderArchitecture.architecture_name:
+        return PhiDecoderArchitecture(config)
 
     if arch_name == PHI2_INFO.name:
         if config.model_type == "phi-msft":
@@ -474,11 +330,7 @@ def get_architecture_info(config: PretrainedConfig) -> StaticTensorNames:
         CHATGLM_INFO,
         STABLELM_INFO,
         JAIS_INFO,
-        BAICHUAN_INFO,
-        FALCON_INFO,
     ]
     for arch in supported:
         if arch.name == arch_name:
             return arch
-
-    raise RuntimeError(f"Unsupported architecture {arch_name}")
diff --git a/mergekit/config.py b/mergekit/config.py
index 
cff31921..45217629 100644
--- a/mergekit/config.py
+++ b/mergekit/config.py
@@ -81,14 +81,28 @@ class OutputSliceDefinition(BaseModel):
     parameters: Optional[Dict[str, ParameterSetting]] = None
 
 
-class MergeConfiguration(BaseModel):
-    merge_method: str
+class OutputModuleDefinition(BaseModel):
     slices: Optional[List[OutputSliceDefinition]] = None
     models: Optional[List[InputModelDefinition]] = None
     parameters: Optional[Dict[str, ParameterSetting]] = None
+
+    @model_validator(mode="after")
+    def validate_inputs(self):
+        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
+            raise RuntimeError("Must specify either output slices or models to merge")
+        return self
+
+
+class MergeConfiguration(BaseModel):
+    modules: Optional[Dict[str, OutputModuleDefinition]] = None
+    slices: Optional[List[OutputSliceDefinition]] = None
+    models: Optional[List[InputModelDefinition]] = None
+
+    merge_method: str
     base_model: Optional[ModelReference] = None
     dtype: Optional[str] = None
     tokenizer_source: Optional[str] = None
+    parameters: Optional[Dict[str, ParameterSetting]] = None
 
     def referenced_models(self) -> List[ModelReference]:
         models = set()
@@ -101,12 +115,31 @@ def referenced_models(self) -> List[ModelReference]:
         for s in self.slices:
             for src in s.sources:
                 models.add(src.model)
+        if self.modules:
+            for m in self.modules.values():
+                if m.models:
+                    for model_in in m.models:
+                        models.add(model_in.model)
+                if m.slices:
+                    for s in m.slices:
+                        for src in s.sources:
+                            models.add(src.model)
         return list(models)
 
     @model_validator(mode="after")
     def validate_inputs(self):
-        if ((not self.slices) and (not self.models)) or (self.slices and self.models):
-            raise RuntimeError("Must specify either output slices or models to merge")
+        set_ct = 0
+        if self.modules:
+            set_ct += 1
+        if self.slices:
+            set_ct += 1
+        if self.models:
+            set_ct += 1
+
+        if set_ct != 1:
+            raise RuntimeError(
+                "Exactly one of 'models', 'slices', or 'modules' must be present"
+            )
         return self
 
     def to_yaml(self) -> str:
@@ -121,6 +154,7 @@ class ConfigReader(BaseModel):
     t: float
     tensor_name: Optional[str] = None
     slice_out: Optional[OutputSliceDefinition] = None
+    module: Optional[OutputModuleDefinition] = None
 
     @property
     def base_model(self) -> Optional[ModelReference]:
@@ -137,6 +171,7 @@ def for_out_slice(self, slice: OutputSliceDefinition) -> "ConfigReader":
             t=self.t,
             tensor_name=self.tensor_name,
             slice_out=slice,
+            module=self.module,
         )
 
     def for_tensor(self, tensor_name: str) -> "ConfigReader":
@@ -145,6 +180,7 @@ def for_tensor(self, tensor_name: str) -> "ConfigReader":
             t=self.t,
             tensor_name=tensor_name,
             slice_out=self.slice_out,
+            module=self.module,
         )
 
     def with_t(self, t: float) -> "ConfigReader":
@@ -153,6 +189,16 @@ def with_t(self, t: float) -> "ConfigReader":
             t=t,
             tensor_name=self.tensor_name,
             slice_out=self.slice_out,
+            module=self.module,
+        )
+
+    def for_module(self, module: OutputModuleDefinition) -> "ConfigReader":
+        return ConfigReader(
+            config=self.config,
+            t=self.t,
+            tensor_name=self.tensor_name,
+            slice_out=self.slice_out,
+            module=module,
         )
 
     def parameter(
@@ -179,6 +225,15 @@ def parameter(
         if value is not None:
             return value
 
+        if self.module and self.module.parameters and name in self.module.parameters:
+            value = evaluate_setting(
+                self.tensor_name,
+                self.module.parameters[name],
+                self.t,
+            )
+            if value is not None:
+                return value
+
         if self.config.parameters and name in self.config.parameters:
             value = evaluate_setting(
                 self.tensor_name,
diff --git a/mergekit/merge.py b/mergekit/merge.py
index 7550e49f..28d7f947 
100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -20,7 +20,7 @@ import tqdm import transformers -from mergekit.architecture import ArchitectureInfo, get_architecture_info +from mergekit.architecture import ModelArchitecture, get_architecture_info from mergekit.card import generate_card from mergekit.config import MergeConfiguration from mergekit.graph import Executor @@ -141,7 +141,7 @@ def _get_donor_tokenizer( def _model_out_config( config: MergeConfiguration, - arch_info: ArchitectureInfo, + arch_info: ModelArchitecture, tokenizer: Optional[transformers.PreTrainedTokenizerBase] = None, trust_remote_code: bool = False, ) -> transformers.PretrainedConfig: @@ -162,17 +162,23 @@ def _model_out_config( exc_info=e, ) - try: - num_layers = sum( - s.sources[0].layer_range[1] - s.sources[0].layer_range[0] - for s in config.slices - ) - setattr(res, arch_info.num_layers_config_key(), num_layers) - except Exception as e: - logging.warning( - "Unable to set number of layers in output config - you may need to manually correct it.", - exc_info=e, - ) + for module_name, module_def in config.modules.items(): + module_info = arch_info.modules[module_name] + cfg_key = ( + module_info.config_prefix or "" + ) + module_info.architecture.num_layers_config_key() + try: + num_layers = sum( + s.sources[0].layer_range[1] - s.sources[0].layer_range[0] + for s in module_def.slices + ) + setattr(res, cfg_key, num_layers) + except Exception as e: + logging.warning( + f"Unable to set number of layers for module {module_name} in output config " + "- you may need to manually correct it.", + exc_info=e, + ) return res diff --git a/mergekit/plan.py b/mergekit/plan.py index fac2f216..b53d2a4c 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -14,15 +14,18 @@ # along with this program. If not, see http://www.gnu.org/licenses/. 
import logging -from typing import List, Optional +import os +from functools import lru_cache +from typing import Dict, List, Optional from mergekit import merge_methods -from mergekit.architecture import ArchitectureInfo +from mergekit.architecture import ModelArchitecture, WeightInfo from mergekit.common import ImmutableMap, ModelReference from mergekit.config import ( ConfigReader, InputSliceDefinition, MergeConfiguration, + OutputModuleDefinition, OutputSliceDefinition, ) from mergekit.graph import Task @@ -35,32 +38,29 @@ class MergePlanner: config: MergeConfiguration - arch_info: ArchitectureInfo - clone_tensors: bool - trust_remote_code: bool + arch_info: ModelArchitecture + options: MergeOptions + out_path: str _writer_task: TensorWriterTask + _tensor_save_tasks: Dict[TensorWriterTask, List[SaveTensor]] _method: MergeMethod _tasks: List[Task] = [] - _current_layers: int = 0 + _current_module_layers: int = 0 _tokenizer_task: Optional[BuildTokenizer] = None def __init__( self, config: MergeConfiguration, - arch_info: ArchitectureInfo, + arch_info: ModelArchitecture, out_path: str, options: MergeOptions, ): self.config = config self.arch_info = arch_info - self.clone_tensors = options.clone_tensors - self.trust_remote_code = options.trust_remote_code + self.options = options + self.out_path = out_path self._method = merge_methods.get(config.merge_method) - self._writer_task = TensorWriterTask( - out_path=out_path, - max_shard_size=options.out_shard_size, - safe_serialization=options.safe_serialization, - ) + self._tensor_save_tasks = {} if config.tokenizer_source: self._tokenizer_task = BuildTokenizer( @@ -73,45 +73,82 @@ def __init__( def normalize_config(self): base_model = self.config.base_model - # if models to merge are specified instead of output slices, compute them + # models -> modules.models if self.config.models: - if self.config.slices: + self.config.modules = {} + for module_name in self.arch_info.modules: + self.config.modules[module_name] = OutputModuleDefinition( + name=module_name, models=self.config.models + ) + self.config.models = None + + # slices -> modules.slices + if self.config.slices: + if len(self.arch_info.modules) != 1: raise RuntimeError( - "Must specify either models to merge or output slices" + "Model has multiple modules, must use modules: syntax" ) + module_name = list(self.arch_info.modules.keys())[0] + self.config.modules = { + module_name: OutputModuleDefinition(slices=self.config.slices) + } + self.config.slices = None - slices_in = [] - base_included = False + # modules.models -> modules.slices + for module_name in self.config.modules: + module_out = self.config.modules[module_name] + num_layers_key = ( + self.arch_info.modules[module_name].config_prefix or "" + ) + self.arch_info.modules[module_name].architecture.num_layers_config_key() - for model_in in self.config.models: - if base_model and model_in.model == base_model: - base_included = True + if module_out.models: + slices_in = [] + base_included = False - model_cfg = model_in.model.config( - trust_remote_code=self.trust_remote_code - ) - num_layers = self.arch_info.num_layers(model_cfg) - slices_in.append( - InputSliceDefinition( - layer_range=[0, num_layers], - model=model_in.model, - parameters=model_in.parameters, + for model_in in module_out.models: + if base_model and model_in.model == base_model: + base_included = True + + model_cfg = model_in.model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = int(getattr(model_cfg, num_layers_key)) + 
slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=model_in.model, + parameters=model_in.parameters, + ) ) - ) - if base_model and not base_included: - logging.info("Base model specified but not in input models - adding") - base_cfg = base_model.config(trust_remote_code=self.trust_remote_code) - num_layers = self.arch_info.num_layers(base_cfg) - slices_in.append( - InputSliceDefinition( - layer_range=[0, num_layers], - model=base_model, + if base_model and not base_included: + logging.info( + "Base model specified but not in input models - adding" + ) + base_cfg = base_model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = int(getattr(base_cfg, num_layers_key)) + slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=base_model, + ) ) - ) - self.config.slices = [OutputSliceDefinition(sources=slices_in)] - self.config.models = None + module_out.slices = [OutputSliceDefinition(sources=slices_in)] + module_out.models = None + + @lru_cache + def _tensor_writer(self, subfolder: Optional[str] = None): + path = self.out_path + if subfolder: + path = os.path.join(path, subfolder) + return TensorWriterTask( + out_path=path, + max_shard_size=self.options.out_shard_size, + safe_serialization=self.options.safe_serialization, + ) def plan_tensor( self, @@ -119,8 +156,9 @@ def plan_tensor( names_in: List[str], models: List[ModelReference], cfg_reader: ConfigReader, + tensor_writer: TensorWriterTask, + is_embed: bool = False, ): - is_embed = name in self.arch_info.embed_weights() tensor_merge_method = self._method if self._tokenizer_task and is_embed: tensor_merge_method = TokenizerPermutationMerge( @@ -166,9 +204,12 @@ def plan_tensor( save_task = SaveTensor( tensor_name=name, tensor_task=tensor_task, - writer_task=self._writer_task, - clone=self.clone_tensors, + writer_task=tensor_writer, + clone=self.options.clone_tensors, ) + if tensor_writer not in self._tensor_save_tasks: + self._tensor_save_tasks[tensor_writer] = [] + self._tensor_save_tasks[tensor_writer].append(save_task) self._tasks.append(save_task) def plan_layer( @@ -177,22 +218,41 @@ def plan_layer( layer_offset: int, t: float, cfg_reader: ConfigReader, + module_name: str, ): - for name_format in self.arch_info.layer_weight_formats(): - name_out = name_format.format(idx=self._current_layers) - names_in = [ - name_format.format(idx=s.layer_range[0] + layer_offset) for s in sources - ] - + module_arch_def = self.arch_info.modules[module_name] + weights_out: List[WeightInfo] = module_arch_def.architecture.layer_weights( + index=self._current_module_layers + ) + weights_in: List[List[WeightInfo]] = [ + module_arch_def.architecture.layer_weights( + index=s.layer_range[0] + layer_offset + ) + for s in sources + ] + for idx, w_o in enumerate(weights_out): self.plan_tensor( - name=name_out, - names_in=names_in, + name=w_o.prefixed_name(prefix=module_arch_def.weight_prefix), + names_in=[ + weights_in[j][idx].prefixed_name( + prefix=module_arch_def.weight_prefix + ) + for j in range(len(weights_in)) + ], models=[s.model for s in sources], cfg_reader=cfg_reader.with_t(t), + tensor_writer=self._tensor_writer(subfolder=module_arch_def.subfolder), + is_embed=w_o.is_embed, ) - self._current_layers += 1 - def plan_slice(self, definition: OutputSliceDefinition): + self._current_module_layers += 1 + + def plan_slice( + self, + definition: OutputSliceDefinition, + module_def: OutputModuleDefinition, + module_name: str, + ): slice_lengths = [ s.layer_range[1] - 
s.layer_range[0] for s in definition.sources ] @@ -202,7 +262,9 @@ def plan_slice(self, definition: OutputSliceDefinition): ) num_layers = slice_lengths[0] - cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0) + cfg_reader = ConfigReader( + config=self.config, slice_out=definition, module=module_def, t=0 + ) for idx in range(num_layers): # compute t for interpolated gradients if num_layers > 1: @@ -215,44 +277,66 @@ def plan_slice(self, definition: OutputSliceDefinition): layer_offset=idx, t=t, cfg_reader=cfg_reader, + module_name=module_name, ) - def plan(self): - self.normalize_config() - self._tasks = [] + def plan_module(self, module_name: str, definition: OutputModuleDefinition): + self._current_module_layers = 0 - for weight_name in self.arch_info.pre_weights(): + module_arch_def = self.arch_info.modules[module_name] + config_reader = ConfigReader(config=self.config, t=0, module=definition) + + for weight_info in module_arch_def.architecture.pre_weights(): + weight_name = weight_info.prefixed_name( + prefix=module_arch_def.weight_prefix + ) self.plan_tensor( weight_name, - [weight_name] * len(self.config.slices[0].sources), - [s.model for s in self.config.slices[0].sources], - ConfigReader( - config=self.config, - t=0, - tensor_name=weight_name, - ).for_out_slice(self.config.slices[0]), + [weight_name] * len(definition.slices[0].sources), + [s.model for s in definition.slices[0].sources], + config_reader.for_tensor(tensor_name=weight_name).for_out_slice( + definition.slices[0] + ), + tensor_writer=self._tensor_writer(subfolder=module_arch_def.subfolder), + is_embed=weight_info.is_embed, ) - for out_slice in self.config.slices: - self.plan_slice(out_slice) + for out_slice in definition.slices: + self.plan_slice( + out_slice, + module_def=definition, + module_name=module_name, + ) - for weight_name in self.arch_info.post_weights(): + for weight_info in module_arch_def.architecture.post_weights(): + weight_name = weight_info.prefixed_name( + prefix=module_arch_def.weight_prefix + ) self.plan_tensor( weight_name, - [weight_name] * len(self.config.slices[-1].sources), - [s.model for s in self.config.slices[-1].sources], - ConfigReader( - config=self.config, - t=1, - tensor_name=weight_name, - ).for_out_slice(self.config.slices[-1]), + [weight_name] * len(definition.slices[-1].sources), + [s.model for s in definition.slices[-1].sources], + config_reader.for_tensor(tensor_name=weight_name).for_out_slice( + definition.slices[-1] + ), + tensor_writer=self._tensor_writer(subfolder=module_arch_def.subfolder), + is_embed=weight_info.is_embed, ) - self._tasks.append( - FinalizeModel( - tensor_save_tasks=tuple(self._tasks), writer_task=self._writer_task + def plan(self): + self.normalize_config() + self._tasks = [] + + for module_name in self.config.modules: + self.plan_module(module_name, self.config.modules[module_name]) + + for writer in self._tensor_save_tasks: + self._tasks.append( + FinalizeModel( + tensor_save_tasks=tuple(self._tensor_save_tasks[writer]), + writer_task=writer, + ) ) - ) res = list(self._tasks) if self._tokenizer_task: res.append(self._tokenizer_task) diff --git a/tests/common.py b/tests/common.py index 55566ad6..cc871a3a 100644 --- a/tests/common.py +++ b/tests/common.py @@ -46,7 +46,7 @@ def run_and_check_merge( arch_info = get_architecture_info(config) index = ShardedTensorIndex.from_disk(tmpdir) - for tensor_name in arch_info.all_weights(config): + for tensor_name in arch_info.all_weights(): if tensor_name not in index.tensor_paths: raise 
RuntimeError(f"Output missing tensor {tensor_name}") From a898ad95378f89e744a0067bed4f7ed98b2910f0 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Mon, 29 Jan 2024 18:34:30 -0800 Subject: [PATCH 2/7] Cleanup --- mergekit/architecture/base.py | 28 ++++++++++++++++------------ tests/common.py | 6 +++--- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py index 9835e96a..0840c10a 100644 --- a/mergekit/architecture/base.py +++ b/mergekit/architecture/base.py @@ -59,13 +59,13 @@ def num_layers_config_key(self) -> str: """Key in config that represents number of layers""" return "num_hidden_layers" - def all_weights(self) -> List[str]: + def all_weights(self) -> List[WeightInfo]: num_layers = self.num_layers() - tensor_names = list(self.pre_weights()) + res = list(self.pre_weights()) for layer_idx in range(num_layers): - tensor_names.extend(self.layer_weights(index=layer_idx)) - tensor_names.extend(self.post_weights()) - return [ti.name if isinstance(ti, WeightInfo) else ti for ti in tensor_names] + res.extend(self.layer_weights(index=layer_idx)) + res.extend(self.post_weights()) + return res class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True): @@ -75,16 +75,20 @@ class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True): subfolder: Optional[str] = None -class ModelArchitecture(BaseModel, frozen=True, arbitrary_types_allowed=True): +class ModelArchitecture(BaseModel, frozen=True): modules: Dict[str, ModuleDefinition] - def all_weights(self) -> List[str]: - tensor_names = [] + def all_weights(self) -> List[WeightInfo]: + res = [] for module in self.modules.values(): - prefix = module.weight_prefix or "" - for name in module.architecture.all_weights(): - tensor_names.append(prefix + name) - return tensor_names + for weight_info in module.architecture.all_weights(): + res.append( + WeightInfo( + name=weight_info.prefixed_name(module.weight_prefix), + is_embed=weight_info.is_embed, + ) + ) + return res class StaticLayeredModuleArchitecture(ModuleArchitecture, BaseModel, frozen=True): diff --git a/tests/common.py b/tests/common.py index cc871a3a..e4b02374 100644 --- a/tests/common.py +++ b/tests/common.py @@ -46,9 +46,9 @@ def run_and_check_merge( arch_info = get_architecture_info(config) index = ShardedTensorIndex.from_disk(tmpdir) - for tensor_name in arch_info.all_weights(): - if tensor_name not in index.tensor_paths: - raise RuntimeError(f"Output missing tensor {tensor_name}") + for weight_info in arch_info.all_weights(): + if weight_info.name not in index.tensor_paths: + raise RuntimeError(f"Output missing tensor {weight_info.name}") if validate: validate(tmpdir) From 5ca1c51554684f163e65f525a04b094806ab0e04 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 5 Oct 2024 13:56:11 -0700 Subject: [PATCH 3/7] Fix --- mergekit/architecture/base.py | 10 ++++++++-- mergekit/config.py | 12 +++--------- mergekit/plan.py | 7 ++++--- tests/test_chat_template.py | 2 ++ 4 files changed, 17 insertions(+), 14 deletions(-) diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py index 01e6bde8..c3c62d5b 100644 --- a/mergekit/architecture/base.py +++ b/mergekit/architecture/base.py @@ -14,13 +14,14 @@ # along with this program. If not, see http://www.gnu.org/licenses/. 
from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple from pydantic import BaseModel from transformers import PretrainedConfig from typing_extensions import Literal + class WeightInfo(BaseModel, frozen=True): """Information about an individual weight tensor in a model. @@ -37,6 +38,8 @@ class WeightInfo(BaseModel, frozen=True): Indicates whether the weight can be omitted from a model. aliases (Optional[List[str]]): List of alternative names for the weight, if applicable. + force_dtype (Optional[str]): + Mandatory dtype for the weight, if applicable. """ name: str @@ -44,7 +47,10 @@ class WeightInfo(BaseModel, frozen=True): input_space: Optional[str] = None output_space: Optional[str] = None optional: bool = False - aliases: Optional[List[str]] = None + aliases: Optional[Tuple[str, ...]] = None + force_dtype: Optional[str] = None + head_split: Literal[None, "input", "output"] = None + is_kq: Optional[bool] = False class ProceduralSpaceInfo(BaseModel, frozen=True): diff --git a/mergekit/config.py b/mergekit/config.py index b0e5b5a9..ed874852 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -102,9 +102,9 @@ class MergeConfiguration(BaseModel): merge_method: str base_model: Optional[ModelReference] = None dtype: Optional[str] = None - tokenizer_source: Union[ - Literal["union"], Literal["base"], ModelReference, None - ] = None + tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = ( + None + ) tokenizer: Optional[TokenizerConfig] = None chat_template: Optional[str] = None out_dtype: Optional[str] = None @@ -154,12 +154,6 @@ def validate_tokenizer(self): raise RuntimeError("Cannot specify both tokenizer_source and tokenizer") return self - @model_validator(mode="after") - def validate_tokenizer(self): - if self.tokenizer_source and self.tokenizer: - raise RuntimeError("Cannot specify both tokenizer_source and tokenizer") - return self - def to_yaml(self) -> str: return yaml.dump( self.model_dump(exclude_defaults=True, mode="json"), diff --git a/mergekit/plan.py b/mergekit/plan.py index 4be38b49..2f370422 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -54,7 +54,6 @@ class MergePlanner: _tensors: List[Tuple[WeightInfo, Task]] _current_module_layers: int = 0 _tokenizer_task: Optional[BuildTokenizer] = None - _tensor_save_tasks: Dict[TensorWriterTask, List[SaveTensor]] def __init__( self, @@ -68,6 +67,7 @@ def __init__( self.options = options self.out_model_config = out_model_config self._method = merge_methods.get(config.merge_method) + self._tensors = [] token_cfg = {} tokenizer_source = config.tokenizer_source @@ -220,8 +220,9 @@ def plan_tensor( base_model=base_model, ) + print(f"output_weight: {repr(weight)} ({type(weight)})") tensor_task = tensor_merge_method.make_task( - output_weight=weight, + output_weight=weight.model_dump(), tensors=tensor_input_task, parameters=ImmutableMap(data=global_params), tensor_parameters=ImmutableMap( @@ -375,7 +376,7 @@ def plan_in_memory(self) -> List[ReturnTensor]: for w, t in self._tensors ] - def plan(self): + def _plan(self): self.normalize_config() self._tasks = [] diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py index 1db8552c..f678b25e 100644 --- a/tests/test_chat_template.py +++ b/tests/test_chat_template.py @@ -4,6 +4,8 @@ from transformers import AutoTokenizer from mergekit.config import InputModelDefinition, MergeConfiguration +from test_tokenizer import model_base # pylint: disable=unused-import +from 
test_basic_merges import model_b # pylint: disable=unused-import def check_chat_template(model_path: str, needle: Optional[str] = None): From 09249932b4ccfed9c5a453111334504052b14e30 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 5 Oct 2024 14:00:03 -0700 Subject: [PATCH 4/7] Cleanup --- mergekit/architecture/base.py | 1 - mergekit/config.py | 6 +++--- mergekit/plan.py | 2 +- tests/test_chat_template.py | 18 +++++++++++++++--- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py index c3c62d5b..4ac3ef63 100644 --- a/mergekit/architecture/base.py +++ b/mergekit/architecture/base.py @@ -21,7 +21,6 @@ from typing_extensions import Literal - class WeightInfo(BaseModel, frozen=True): """Information about an individual weight tensor in a model. diff --git a/mergekit/config.py b/mergekit/config.py index ed874852..6edb579e 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -102,9 +102,9 @@ class MergeConfiguration(BaseModel): merge_method: str base_model: Optional[ModelReference] = None dtype: Optional[str] = None - tokenizer_source: Union[Literal["union"], Literal["base"], ModelReference, None] = ( - None - ) + tokenizer_source: Union[ + Literal["union"], Literal["base"], ModelReference, None + ] = None tokenizer: Optional[TokenizerConfig] = None chat_template: Optional[str] = None out_dtype: Optional[str] = None diff --git a/mergekit/plan.py b/mergekit/plan.py index 2f370422..81e3d476 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -15,7 +15,7 @@ import logging from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple from mergekit import merge_methods from mergekit.architecture import ( diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py index f678b25e..2bd41cde 100644 --- a/tests/test_chat_template.py +++ b/tests/test_chat_template.py @@ -1,11 +1,23 @@ from typing import Optional -from common import run_and_check_merge +import pytest +from common import make_picollama, run_and_check_merge +from test_tokenizer import make_tokenizer from transformers import AutoTokenizer from mergekit.config import InputModelDefinition, MergeConfiguration -from test_tokenizer import model_base # pylint: disable=unused-import -from test_basic_merges import model_b # pylint: disable=unused-import + + +@pytest.fixture(scope="session") +def model_base(tmp_path_factory): + model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64) + make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path) + return model_path + + +@pytest.fixture(scope="session") +def model_b(tmp_path_factory): + return make_picollama(tmp_path_factory.mktemp("model_b")) def check_chat_template(model_path: str, needle: Optional[str] = None): From e59e366f4b34a9eb8c65bb8b9b2a2b6725d9ee66 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 5 Oct 2024 14:02:16 -0700 Subject: [PATCH 5/7] Fix duplicated code --- mergekit/architecture/decoder_only.py | 137 +------------------------- mergekit/plan.py | 2 +- 2 files changed, 6 insertions(+), 133 deletions(-) diff --git a/mergekit/architecture/decoder_only.py b/mergekit/architecture/decoder_only.py index 488a4045..671a2d1c 100644 --- a/mergekit/architecture/decoder_only.py +++ b/mergekit/architecture/decoder_only.py @@ -23,138 +23,11 @@ from typing_extensions import Literal import mergekit._data.architectures -from mergekit.architecture.base import ModuleArchitecture - - 
-class WeightInfo(BaseModel, frozen=True): - """Information about an individual weight tensor in a model. - - Attributes: - name (str): - The name of the tensor representing the weight. - is_embed (bool): - Indicates whether the weight is for an embedding or language model head. - input_space (Optional[str]): - The name of the input space associated with the weight, if applicable. - output_space (Optional[str]): - The name of the output space associated with the weight, if applicable. - optional (bool): - Indicates whether the weight can be omitted from a model. - aliases (Optional[List[str]]): - List of alternative names for the weight, if applicable. - force_dtype (Optional[str]): - Mandatory dtype for the weight, if applicable. - """ - - name: str - is_embed: bool = False - input_space: Optional[str] = None - output_space: Optional[str] = None - optional: bool = False - aliases: Optional[Tuple[str, ...]] = None - force_dtype: Optional[str] = None - head_split: Literal[None, "input", "output"] = None - is_kq: Optional[bool] = False - - -class ProceduralSpaceInfo(BaseModel, frozen=True): - """Defines a procedural space computed from one or more other spaces. - - Currently only supports residual connections. - - Attributes: - name (str): The name of the space defined. - type (str): The type of procedural space. - inputs (List[str]): List of names of spaces used to define this space.""" - - name: str - type: Literal["residual"] - inputs: List[str] - - -class ArchitectureInfo(ABC): - @abstractmethod - def name(self) -> str: - """Return the name of the architecture.""" - ... - - @abstractmethod - def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return a list of all weights preceding the first layer.""" - ... - - @abstractmethod - def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return a list of all weights following the final layer.""" - ... - - @abstractmethod - def layer_weights( - self, index: int, config: PretrainedConfig - ) -> Optional[List[WeightInfo]]: - """Return a list of all weights associated with a given layer.""" - ... - - @abstractmethod - def sliceable(self) -> bool: - """ - Return True if the layers of this architecture can be meaningfully sliced. - """ - ... - - def num_layers_config_key(self) -> str: - """Key in config that represents number of layers""" - return "num_hidden_layers" - - def num_layers(self, config: PretrainedConfig) -> int: - """Return the number of layers in a model.""" - return getattr(config, self.num_layers_config_key()) - - def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]: - """Return all weights associated with a model.""" - num_layers = self.num_layers(config) - res = list(self.pre_weights(config)) - for layer_idx in range(num_layers): - res.extend(self.layer_weights(layer_idx, config)) - res.extend(self.post_weights(config)) - return res - - def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]: - """Return a list of all procedurally defined spaces in a model.""" - return [] - - def has_defined_spaces(self) -> bool: - """ - Return True if this architecture defines space information needed for - matching-based merge methods. 
- """ - return False - - -class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): - info: ArchitectureInfo - config: PretrainedConfig - - def name(self) -> str: - return self.info.name() - - def num_layers(self) -> int: - return self.info.num_layers(self.config) - - def pre_weights(self) -> List[WeightInfo]: - return self.info.pre_weights(self.config) - - def post_weights(self) -> List[WeightInfo]: - return self.info.post_weights(self.config) - - def layer_weights(self, index: int) -> List[WeightInfo]: - return self.info.layer_weights(index, self.config) - - def procedural_spaces(self) -> List[ProceduralSpaceInfo]: - return self.info.procedural_spaces(self.config) - - def all_weights(self) -> List[WeightInfo]: - return self.info.all_weights(self.config) +from mergekit.architecture.base import ( + ModuleArchitecture, + ProceduralSpaceInfo, + WeightInfo, +) class JSONLayerTemplates(BaseModel, frozen=True): diff --git a/mergekit/plan.py b/mergekit/plan.py index 81e3d476..8914042a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -222,7 +222,7 @@ def plan_tensor( print(f"output_weight: {repr(weight)} ({type(weight)})") tensor_task = tensor_merge_method.make_task( - output_weight=weight.model_dump(), + output_weight=weight, tensors=tensor_input_task, parameters=ImmutableMap(data=global_params), tensor_parameters=ImmutableMap( From 17cc91d54941b067f357ed13aa78977fe7713616 Mon Sep 17 00:00:00 2001 From: Charles Goddard Date: Sat, 5 Oct 2024 14:03:53 -0700 Subject: [PATCH 6/7] More fix --- mergekit/architecture/decoder_only.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/mergekit/architecture/decoder_only.py b/mergekit/architecture/decoder_only.py index 671a2d1c..fa36591b 100644 --- a/mergekit/architecture/decoder_only.py +++ b/mergekit/architecture/decoder_only.py @@ -15,12 +15,10 @@ import importlib.resources import string -from abc import ABC, abstractmethod from typing import ClassVar, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, Field from transformers import PretrainedConfig -from typing_extensions import Literal import mergekit._data.architectures from mergekit.architecture.base import ( @@ -73,30 +71,6 @@ def _template_substitution( return TemplateWithArithmetic(template).substitute(substitutions) -def _template_substitution( - template: str, num_layers: int, layer_idx: Optional[int] = None -) -> str: - if "{" not in template: - return template - - substitutions = { - "num_layers": num_layers, - "num_layers+1": num_layers + 1, - "num_layers-1": num_layers - 1, - } - - if layer_idx is not None: - substitutions.update( - { - "layer_index": layer_idx, - "layer_index+1": layer_idx + 1, - "layer_index-1": layer_idx - 1, - } - ) - - return TemplateWithArithmetic(template).substitute(substitutions) - - class JsonArchitectureInfo(ModuleArchitecture, BaseModel, frozen=True): definition: JSONArchitectureDefinition From 56054b1d7065c6c0e2eba44c992b974e60181c10 Mon Sep 17 00:00:00 2001 From: Nottlespike <151680919+Nottlespike@users.noreply.github.com> Date: Fri, 11 Oct 2024 20:45:49 -0700 Subject: [PATCH 7/7] Update moe.md --- docs/moe.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/moe.md b/docs/moe.md index 1d62d4a3..b5e90134 100644 --- a/docs/moe.md +++ b/docs/moe.md @@ -49,7 +49,7 @@ An appropriate architecture will be inferred based on the input models and prese ```yml base_model: path/to/self_attn_donor -architecture: qwen +architecture: Qwen MoE # Needed if using the Qwen MoE 
architecture with Qwen2.5 # ... and so on ```
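
For illustration, here is a minimal sketch of the module-aware architecture API that patches 1 and 2 introduce. The architecture name and tensor names below are invented for the example; only the field and method names come from the code in the series.

```python
from mergekit.architecture.base import (
    ModelArchitecture,
    ModuleDefinition,
    StaticLayeredModuleArchitecture,
)

# A toy two-layer decoder described with the static, template-based module
# description from PATCH 1/7. All tensor names here are illustrative.
TINY_INFO = StaticLayeredModuleArchitecture(
    name="TinyForCausalLM",
    pre_weight_names=["model.embed_tokens.weight"],
    post_weight_names=["model.norm.weight", "lm_head.weight"],
    embed_weight_names=["model.embed_tokens.weight", "lm_head.weight"],
    layer_prefix_format="model.layers.{idx}",
    layer_weight_suffixes=["input_layernorm.weight", "mlp.up_proj.weight"],
    configured_num_layers=2,
)

# A model is a named collection of modules; a multi-module model (e.g. an
# encoder-decoder) would add more entries, each with its own weight_prefix
# and optional subfolder.
TINY_ARCH = ModelArchitecture(
    modules={"decoder": ModuleDefinition(architecture=TINY_INFO)}
)

# As of PATCH 2/7, all_weights() yields WeightInfo objects rather than bare
# tensor names, with any module weight_prefix already applied.
for weight_info in TINY_ARCH.all_weights():
    print(weight_info.name, weight_info.is_embed)
```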
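
The lookup entry point works as before the refactor, just returning the richer type. A sketch only, with a placeholder checkpoint path; single-module decoder-only models land under the "decoder" key per `mergekit/architecture/__init__.py` above:

```python
from transformers import AutoConfig

from mergekit.architecture import get_architecture_info

cfg = AutoConfig.from_pretrained("path/to/model")  # placeholder path
arch_info = get_architecture_info(cfg)  # -> ModelArchitecture, keyed by module

decoder = arch_info.modules["decoder"].architecture
print(decoder.num_layers(), decoder.slicable())
```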
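
On the configuration side, a hypothetical merge spec using the new top-level `modules:` section from PATCH 1/7 might look like the following; the model paths and weights are made up. The flat `models:`/`slices:` forms remain valid and are folded into module form by `MergePlanner.normalize_config`.

```python
import yaml

from mergekit.config import MergeConfiguration

CONFIG_YAML = """
merge_method: linear
modules:
  decoder:
    models:
      - model: path/to/model_a
        parameters:
          weight: 0.5
      - model: path/to/model_b
        parameters:
          weight: 0.5
"""

# Exactly one of `modules:`, `slices:`, or `models:` may be set at top level.
config = MergeConfiguration.model_validate(yaml.safe_load(CONFIG_YAML))
print(config.referenced_models())
```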