diff --git a/docs/moe.md b/docs/moe.md
index 1d62d4a3..b5e90134 100644
--- a/docs/moe.md
+++ b/docs/moe.md
@@ -49,7 +49,7 @@ An appropriate architecture will be inferred based on the input models and prese
 ```yml
 base_model: path/to/self_attn_donor
-architecture: qwen
+architecture: Qwen MoE # Needed if using the Qwen MoE architecture with Qwen2.5
 # ... and so on
 ```
diff --git a/mergekit/architecture/__init__.py b/mergekit/architecture/__init__.py
new file mode 100644
index 00000000..d9398fa9
--- /dev/null
+++ b/mergekit/architecture/__init__.py
@@ -0,0 +1,49 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+
+from transformers import PretrainedConfig
+
+from mergekit.architecture.base import (
+    ModelArchitecture,
+    ModuleArchitecture,
+    ModuleConfiguredArchitecture,
+    ModuleDefinition,
+    WeightInfo,
+)
+from mergekit.architecture.decoder_only import get_decoder_only_arch
+
+
+def get_architecture_info(config: PretrainedConfig) -> ModelArchitecture:
+    if len(config.architectures) != 1:
+        raise RuntimeError("More than one architecture in config?")
+    arch_name = config.architectures[0]
+
+    if decoder := get_decoder_only_arch(arch_name, config=config):
+        return ModelArchitecture(
+            modules={"decoder": ModuleDefinition(architecture=decoder)}
+        )
+
+    raise RuntimeError(f"Unsupported architecture {arch_name}")
+
+
+__all__ = [
+    "ModelArchitecture",
+    "ModuleArchitecture",
+    "ModuleDefinition",
+    "ModuleConfiguredArchitecture",
+    "WeightInfo",
+    "get_architecture_info",
+]
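The resolution path is now: HuggingFace config, then `get_decoder_only_arch`, then wrapping in a single-module `ModelArchitecture`. A minimal sketch of calling the new entry point (the checkpoint path is hypothetical; any supported single-architecture decoder-only config should work):

```python
# Sketch: resolving the architecture of a local checkpoint.
from transformers import AutoConfig

from mergekit.architecture import get_architecture_info

cfg = AutoConfig.from_pretrained("path/to/model")
arch = get_architecture_info(cfg)  # ModelArchitecture wrapping a "decoder" module
for name, module_def in arch.modules.items():
    print(name, module_def.architecture.num_layers(cfg))
```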
diff --git a/mergekit/architecture/base.py b/mergekit/architecture/base.py
new file mode 100644
index 00000000..4ac3ef63
--- /dev/null
+++ b/mergekit/architecture/base.py
@@ -0,0 +1,185 @@
+# Copyright (C) 2024 Charles O. Goddard
+#
+# This software is free software: you can redistribute it and/or
+# modify it under the terms of the GNU Lesser General Public License as
+# published by the Free Software Foundation, either version 3 of the
+# License, or (at your option) any later version.
+#
+# This software is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+# Lesser General Public License for more details.
+#
+# You should have received a copy of the GNU Lesser General Public License
+# along with this program. If not, see http://www.gnu.org/licenses/.
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, Tuple
+
+from pydantic import BaseModel
+from transformers import PretrainedConfig
+from typing_extensions import Literal
+
+
+class WeightInfo(BaseModel, frozen=True):
+    """Information about an individual weight tensor in a model.
+
+    Attributes:
+        name (str):
+            The name of the tensor representing the weight.
+        is_embed (bool):
+            Indicates whether the weight is for an embedding or language model head.
+        input_space (Optional[str]):
+            The name of the input space associated with the weight, if applicable.
+        output_space (Optional[str]):
+            The name of the output space associated with the weight, if applicable.
+        optional (bool):
+            Indicates whether the weight can be omitted from a model.
+        aliases (Optional[List[str]]):
+            List of alternative names for the weight, if applicable.
+        force_dtype (Optional[str]):
+            Mandatory dtype for the weight, if applicable.
+    """
+
+    name: str
+    is_embed: bool = False
+    input_space: Optional[str] = None
+    output_space: Optional[str] = None
+    optional: bool = False
+    aliases: Optional[Tuple[str, ...]] = None
+    force_dtype: Optional[str] = None
+    head_split: Literal[None, "input", "output"] = None
+    is_kq: Optional[bool] = False
+
+
+class ProceduralSpaceInfo(BaseModel, frozen=True):
+    """Defines a procedural space computed from one or more other spaces.
+
+    Currently only supports residual connections.
+
+    Attributes:
+        name (str): The name of the space defined.
+        type (str): The type of procedural space.
+        inputs (List[str]): List of names of spaces used to define this space."""
+
+    name: str
+    type: Literal["residual"]
+    inputs: List[str]
+
+
+def _prefix_weight(weight: WeightInfo, prefix: Optional[str] = None) -> WeightInfo:
+    if prefix is None:
+        return weight
+    return WeightInfo(
+        name=prefix + weight.name,
+        **weight.model_dump(exclude={"name"}),
+    )
+
+
+class ModuleArchitecture(ABC):
+    @abstractmethod
+    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return a list of all weights preceding the first layer."""
+        ...
+
+    @abstractmethod
+    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return a list of all weights following the final layer."""
+        ...
+
+    @abstractmethod
+    def layer_weights(
+        self, index: int, config: PretrainedConfig
+    ) -> Optional[List[WeightInfo]]:
+        """Return a list of all weights associated with a given layer."""
+        ...
+
+    @abstractmethod
+    def sliceable(self) -> bool:
+        """
+        Return True if the layers of this architecture can be meaningfully sliced.
+        """
+        ...
+
+    def num_layers_config_key(self) -> str:
+        """Key in config that represents number of layers"""
+        return "num_hidden_layers"
+
+    def num_layers(self, config: PretrainedConfig) -> int:
+        """Return the number of layers in a model."""
+        return getattr(config, self.num_layers_config_key())
+
+    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        """Return all weights associated with a model."""
+        num_layers = self.num_layers(config)
+        res = list(self.pre_weights(config))
+        for layer_idx in range(num_layers):
+            res.extend(self.layer_weights(layer_idx, config))
+        res.extend(self.post_weights(config))
+        return res
+
+    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
+        """Return a list of all procedurally defined spaces in a model."""
+        return []
+
+    def has_defined_spaces(self) -> bool:
+        """
+        Return True if this architecture defines space information needed for
+        matching-based merge methods.
+        """
+        return False
+
+
+class ModuleConfiguredArchitecture(
+    BaseModel, frozen=True, arbitrary_types_allowed=True
+):
+    info: ModuleArchitecture
+    config: PretrainedConfig
+    weight_prefix: Optional[str] = None
+
+    def num_layers(self) -> int:
+        return self.info.num_layers(self.config)
+
+    def pre_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.pre_weights(self.config)
+        ]
+
+    def post_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.post_weights(self.config)
+        ]
+
+    def layer_weights(self, index: int) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.layer_weights(index, self.config)
+        ]
+
+    def procedural_spaces(self) -> List[ProceduralSpaceInfo]:
+        return self.info.procedural_spaces(self.config)
+
+    def all_weights(self) -> List[WeightInfo]:
+        return [
+            _prefix_weight(w, self.weight_prefix)
+            for w in self.info.all_weights(self.config)
+        ]
+
+
+class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True):
+    architecture: ModuleArchitecture
+    weight_prefix: Optional[str] = None
+    subfolder: Optional[str] = None
+
+
+class ModelArchitecture(BaseModel, frozen=True):
+    modules: Dict[str, ModuleDefinition]
+
+    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
+        res = []
+        for module in self.modules.values():
+            for weight_info in module.architecture.all_weights(config=config):
+                res.append(_prefix_weight(weight_info, module.weight_prefix))
+        return res
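The `ModuleArchitecture` ABC above is the contract every module must satisfy. A minimal sketch of a hypothetical implementation (the tensor names are invented for illustration):

```python
from typing import List, Optional

from transformers import PretrainedConfig

from mergekit.architecture.base import ModuleArchitecture, WeightInfo


class ToyModule(ModuleArchitecture):
    """Hypothetical module: an embedding, N identical layers, and a head."""

    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [WeightInfo(name="embed.weight", is_embed=True)]

    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
        return [WeightInfo(name="head.weight", is_embed=True)]

    def layer_weights(
        self, index: int, config: PretrainedConfig
    ) -> Optional[List[WeightInfo]]:
        return [
            WeightInfo(name=f"layers.{index}.attn.weight"),
            WeightInfo(name=f"layers.{index}.mlp.weight"),
        ]

    def sliceable(self) -> bool:
        return True
```

Wrapped in a `ModuleDefinition` with a `weight_prefix`, each name above is prefixed via `_prefix_weight` when accessed through `ModuleConfiguredArchitecture`.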
+ """ + return False + + +class ModuleConfiguredArchitecture( + BaseModel, frozen=True, arbitrary_types_allowed=True +): + info: ModuleArchitecture + config: PretrainedConfig + weight_prefix: Optional[str] = None + + def num_layers(self) -> int: + return self.info.num_layers(self.config) + + def pre_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.pre_weights(self.config) + ] + + def post_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.post_weights(self.config) + ] + + def layer_weights(self, index: int) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.layer_weights(index, self.config) + ] + + def procedural_spaces(self) -> List[ProceduralSpaceInfo]: + return self.info.procedural_spaces(self.config) + + def all_weights(self) -> List[WeightInfo]: + return [ + _prefix_weight(w, self.weight_prefix) + for w in self.info.all_weights(self.config) + ] + + +class ModuleDefinition(BaseModel, frozen=True, arbitrary_types_allowed=True): + architecture: ModuleArchitecture + weight_prefix: Optional[str] = None + subfolder: Optional[str] = None + + +class ModelArchitecture(BaseModel, frozen=True): + modules: Dict[str, ModuleDefinition] + + def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]: + res = [] + for module in self.modules.values(): + for weight_info in module.architecture.all_weights(config=config): + res.append(_prefix_weight(weight_info, module.weight_prefix)) + return res diff --git a/mergekit/architecture.py b/mergekit/architecture/decoder_only.py similarity index 61% rename from mergekit/architecture.py rename to mergekit/architecture/decoder_only.py index 4c7b4625..fa36591b 100644 --- a/mergekit/architecture.py +++ b/mergekit/architecture/decoder_only.py @@ -15,145 +15,17 @@ import importlib.resources import string -from abc import ABC, abstractmethod from typing import ClassVar, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, Field from transformers import PretrainedConfig -from typing_extensions import Literal import mergekit._data.architectures - - -class WeightInfo(BaseModel, frozen=True): - """Information about an individual weight tensor in a model. - - Attributes: - name (str): - The name of the tensor representing the weight. - is_embed (bool): - Indicates whether the weight is for an embedding or language model head. - input_space (Optional[str]): - The name of the input space associated with the weight, if applicable. - output_space (Optional[str]): - The name of the output space associated with the weight, if applicable. - optional (bool): - Indicates whether the weight can be omitted from a model. - aliases (Optional[List[str]]): - List of alternative names for the weight, if applicable. - force_dtype (Optional[str]): - Mandatory dtype for the weight, if applicable. - """ - - name: str - is_embed: bool = False - input_space: Optional[str] = None - output_space: Optional[str] = None - optional: bool = False - aliases: Optional[Tuple[str, ...]] = None - force_dtype: Optional[str] = None - head_split: Literal[None, "input", "output"] = None - is_kq: Optional[bool] = False - - -class ProceduralSpaceInfo(BaseModel, frozen=True): - """Defines a procedural space computed from one or more other spaces. - - Currently only supports residual connections. - - Attributes: - name (str): The name of the space defined. - type (str): The type of procedural space. 
-
-
-class ArchitectureInfo(ABC):
-    @abstractmethod
-    def name(self) -> str:
-        """Return the name of the architecture."""
-        ...
-
-    @abstractmethod
-    def pre_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return a list of all weights preceding the first layer."""
-        ...
-
-    @abstractmethod
-    def post_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return a list of all weights following the final layer."""
-        ...
-
-    @abstractmethod
-    def layer_weights(
-        self, index: int, config: PretrainedConfig
-    ) -> Optional[List[WeightInfo]]:
-        """Return a list of all weights associated with a given layer."""
-        ...
-
-    @abstractmethod
-    def sliceable(self) -> bool:
-        """
-        Return True if the layers of this architecture can be meaningfully sliced.
-        """
-        ...
-
-    def num_layers_config_key(self) -> str:
-        """Key in config that represents number of layers"""
-        return "num_hidden_layers"
-
-    def num_layers(self, config: PretrainedConfig) -> int:
-        """Return the number of layers in a model."""
-        return getattr(config, self.num_layers_config_key())
-
-    def all_weights(self, config: PretrainedConfig) -> List[WeightInfo]:
-        """Return all weights associated with a model."""
-        num_layers = self.num_layers(config)
-        res = list(self.pre_weights(config))
-        for layer_idx in range(num_layers):
-            res.extend(self.layer_weights(layer_idx, config))
-        res.extend(self.post_weights(config))
-        return res
-
-    def procedural_spaces(self, config: PretrainedConfig) -> List[ProceduralSpaceInfo]:
-        """Return a list of all procedurally defined spaces in a model."""
-        return []
-
-    def has_defined_spaces(self) -> bool:
-        """
-        Return True if this architecture defines space information needed for
-        matching-based merge methods.
- """ - return False - - -class ConfiguredArchitectureInfo(BaseModel, frozen=True, arbitrary_types_allowed=True): - info: ArchitectureInfo - config: PretrainedConfig - - def name(self) -> str: - return self.info.name() - - def num_layers(self) -> int: - return self.info.num_layers(self.config) - - def pre_weights(self) -> List[WeightInfo]: - return self.info.pre_weights(self.config) - - def post_weights(self) -> List[WeightInfo]: - return self.info.post_weights(self.config) - - def layer_weights(self, index: int) -> List[WeightInfo]: - return self.info.layer_weights(index, self.config) - - def procedural_spaces(self) -> List[ProceduralSpaceInfo]: - return self.info.procedural_spaces(self.config) - - def all_weights(self) -> List[WeightInfo]: - return self.info.all_weights(self.config) +from mergekit.architecture.base import ( + ModuleArchitecture, + ProceduralSpaceInfo, + WeightInfo, +) class JSONLayerTemplates(BaseModel, frozen=True): @@ -199,7 +71,7 @@ def _template_substitution( return TemplateWithArithmetic(template).substitute(substitutions) -class JsonArchitectureInfo(ArchitectureInfo, BaseModel, frozen=True): +class JsonArchitectureInfo(ModuleArchitecture, BaseModel, frozen=True): definition: JSONArchitectureDefinition def _substitute( @@ -279,7 +151,7 @@ def num_layers_config_key(self) -> str: return self.definition.num_layers_config_key -class MixtralTensorNames(ArchitectureInfo, BaseModel): +class MixtralTensorNames(ModuleArchitecture, BaseModel): ARCHITECTURE_NAME: ClassVar[str] = "MixtralForCausalLM" num_local_experts: int @@ -355,12 +227,9 @@ def _load_all_architectures() -> ( QWEN2_INFO = _load_json_arch("qwen2.json") -def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: - if len(config.architectures) != 1: - raise RuntimeError("More than one architecture in config?") - - arch_name = config.architectures[0] - +def get_decoder_only_arch( + arch_name: str, config: PretrainedConfig +) -> Optional[ModuleArchitecture]: if arch_name == MixtralTensorNames.ARCHITECTURE_NAME: return MixtralTensorNames.from_config(config) @@ -374,7 +243,3 @@ def get_architecture_info(config: PretrainedConfig) -> ArchitectureInfo: for c in candidates: if c.definition.expected_model_type == config.model_type: return c - - raise RuntimeError( - f"Unsupported model_type {config.model_type} for architecture {arch_name}" - ) diff --git a/mergekit/config.py b/mergekit/config.py index 5c79de7c..6edb579e 100644 --- a/mergekit/config.py +++ b/mergekit/config.py @@ -82,11 +82,24 @@ class OutputSliceDefinition(BaseModel): parameters: Optional[Dict[str, ParameterSetting]] = None -class MergeConfiguration(BaseModel): - merge_method: str +class OutputModuleDefinition(BaseModel): slices: Optional[List[OutputSliceDefinition]] = None models: Optional[List[InputModelDefinition]] = None parameters: Optional[Dict[str, ParameterSetting]] = None + + @model_validator(mode="after") + def validate_inputs(self): + if ((not self.slices) and (not self.models)) or (self.slices and self.models): + raise RuntimeError("Must specify either output slices or models to merge") + return self + + +class MergeConfiguration(BaseModel): + modules: Optional[Dict[str, OutputModuleDefinition]] = None + slices: Optional[List[OutputSliceDefinition]] = None + models: Optional[List[InputModelDefinition]] = None + + merge_method: str base_model: Optional[ModelReference] = None dtype: Optional[str] = None tokenizer_source: Union[ @@ -95,6 +108,7 @@ class MergeConfiguration(BaseModel): tokenizer: Optional[TokenizerConfig] = None 
diff --git a/mergekit/evo/actors.py b/mergekit/evo/actors.py
index e107efe7..ff5c4986 100644
--- a/mergekit/evo/actors.py
+++ b/mergekit/evo/actors.py
@@ -207,7 +207,7 @@ def _maybe_init_model(self, config: MergeConfiguration):
         tokenizer_donor = self.genome.definition.base_model
         if tokenizer_donor is None:
             logging.warning(
-                f"Base model not set, using tokenizer from first model in genome"
+                "Base model not set, using tokenizer from first model in genome"
             )
             tokenizer_donor = self.genome.definition.models[0]
         tok = transformers.AutoTokenizer.from_pretrained(
diff --git a/mergekit/merge.py b/mergekit/merge.py
index 60189f44..c69620f5 100644
--- a/mergekit/merge.py
+++ b/mergekit/merge.py
@@ -25,7 +25,7 @@
 import transformers
 
 from mergekit._data import chat_templates
-from mergekit.architecture import ArchitectureInfo, get_architecture_info
+from mergekit.architecture import ModelArchitecture, get_architecture_info
 from mergekit.card import generate_card
 from mergekit.config import MergeConfiguration
 from mergekit.graph import Executor
@@ -231,7 +231,7 @@ def _copy_tokenizer(
 
 def _model_out_config(
     config: MergeConfiguration,
-    arch_info: ArchitectureInfo,
+    arch_info: ModelArchitecture,
     trust_remote_code: bool = False,
 ) -> transformers.PretrainedConfig:
     """Return a configuration for the resulting model."""
@@ -244,19 +244,33 @@ def _model_out_config(
     elif config.dtype:
         res.torch_dtype = config.dtype
 
-    if config.slices:
-        try:
-            num_layers = sum(
+    module_layers = {}
+    for module_name in arch_info.modules:
+        if config.modules and module_name in config.modules:
+            module_def = config.modules.get(module_name)
+            module_layers[module_name] = sum(
                 s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
-                for s in config.slices
+                for s in module_def.slices
             )
-            setattr(res, arch_info.num_layers_config_key(), num_layers)
-        except Exception as e:
-            logging.warning(
-                "Unable to set number of layers in output config - you may need to manually correct it.",
-                exc_info=e,
+        elif config.slices:
+            module_layers[module_name] = sum(
+                s.sources[0].layer_range[1] - s.sources[0].layer_range[0]
+                for s in config.slices
             )
+
+    if module_layers:
+        for module_name in module_layers:
+            try:
+                module_info = arch_info.modules[module_name]
+                cfg_key = module_info.architecture.num_layers_config_key()
+                setattr(res, cfg_key, module_layers[module_name])
+            except Exception as e:
+                logging.warning(
+                    f"Unable to set number of layers for module {module_name} in output config "
+                    "- you may need to manually correct it.",
+                    exc_info=e,
+                )
+
     return res
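The layer-count bookkeeping is unchanged in substance, just keyed per module. For example, assuming a hypothetical two-slice decoder:

```python
# Hypothetical decoder built from two slices: layers 0-16 of one source
# model and layers 8-32 of another.
layer_ranges = [(0, 16), (8, 32)]
num_layers = sum(end - start for start, end in layer_ranges)  # 16 + 24 = 40
# _model_out_config would then set res.num_hidden_layers = 40 (or whatever
# key that module's num_layers_config_key() names).
```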
diff --git a/mergekit/plan.py b/mergekit/plan.py
index bdcd7004..8914042a 100644
--- a/mergekit/plan.py
+++ b/mergekit/plan.py
@@ -19,8 +19,8 @@
 
 from mergekit import merge_methods
 from mergekit.architecture import (
-    ArchitectureInfo,
-    ConfiguredArchitectureInfo,
+    ModelArchitecture,
+    ModuleConfiguredArchitecture,
     WeightInfo,
 )
 from mergekit.common import ImmutableMap, ModelReference
@@ -28,6 +28,7 @@
     ConfigReader,
     InputSliceDefinition,
     MergeConfiguration,
+    OutputModuleDefinition,
     OutputSliceDefinition,
 )
 from mergekit.graph import Task
@@ -46,18 +47,18 @@ class MergePlanner:
     config: MergeConfiguration
-    arch_info: ArchitectureInfo
+    arch_info: ModelArchitecture
     options: MergeOptions
     out_model_config: Any
     _method: MergeMethod
     _tensors: List[Tuple[WeightInfo, Task]]
-    _current_layers: int = 0
+    _current_module_layers: int = 0
     _tokenizer_task: Optional[BuildTokenizer] = None
 
     def __init__(
         self,
         config: MergeConfiguration,
-        arch_info: ArchitectureInfo,
+        arch_info: ModelArchitecture,
         options: MergeOptions,
         out_model_config: Any,
     ):
@@ -66,6 +67,7 @@ def __init__(
         self.options = options
         self.out_model_config = out_model_config
         self._method = merge_methods.get(config.merge_method)
+        self._tensors = []
 
         token_cfg = {}
         tokenizer_source = config.tokenizer_source
@@ -82,50 +84,80 @@ def __init__(
         )
 
     @lru_cache
-    def model_arch_info(self, model: ModelReference):
-        return ConfiguredArchitectureInfo(
-            info=self.arch_info,
+    def _model_module_arch(self, model: ModelReference, module_name: str):
+        module_def = self.arch_info.modules[module_name]
+        return ModuleConfiguredArchitecture(
+            info=module_def.architecture,
             config=model.config(trust_remote_code=self.options.trust_remote_code),
+            weight_prefix=module_def.weight_prefix,
         )
 
     def normalize_config(self):
         base_model = self.config.base_model
 
-        # if models to merge are specified instead of output slices, compute them
+        # models -> modules.models
         if self.config.models:
-            if self.config.slices:
-                raise RuntimeError(
-                    "Must specify either models to merge or output slices"
+            self.config.modules = {}
+            for module_name in self.arch_info.modules:
+                self.config.modules[module_name] = OutputModuleDefinition(
+                    name=module_name, models=self.config.models
                 )
+            self.config.models = None
 
-            slices_in = []
-            base_included = False
-
-            for model_in in self.config.models:
-                if base_model and model_in.model == base_model:
-                    base_included = True
-
-                model_info = self.model_arch_info(model_in.model)
-                slices_in.append(
-                    InputSliceDefinition(
-                        layer_range=[0, model_info.num_layers()],
-                        model=model_in.model,
-                        parameters=model_in.parameters,
-                    )
+        # slices -> modules.slices
+        if self.config.slices:
+            if len(self.arch_info.modules) != 1:
+                raise RuntimeError(
+                    "Model has multiple modules, must use modules: config syntax"
                 )
+            module_name = list(self.arch_info.modules.keys())[0]
+            self.config.modules = {
+                module_name: OutputModuleDefinition(slices=self.config.slices)
+            }
+            self.config.slices = None
+
+        # modules.models -> modules.slices
+        for module_name in self.config.modules:
+            module_out = self.config.modules[module_name]
+            module_arch = self.arch_info.modules[module_name].architecture
+
+            if module_out.models:
+                slices_in = []
+                base_included = False
+
+                for model_in in module_out.models:
+                    if base_model and model_in.model == base_model:
+                        base_included = True
+
+                    model_cfg = model_in.model.config(
+                        trust_remote_code=self.options.trust_remote_code
+                    )
+                    num_layers = module_arch.num_layers(model_cfg)
+                    slices_in.append(
+                        InputSliceDefinition(
+                            layer_range=[0, num_layers],
+                            model=model_in.model,
+                            parameters=model_in.parameters,
+                        )
+                    )
 
-            if base_model and not base_included:
-                logging.info("Base model specified but not in input models - adding")
-                base_info = self.model_arch_info(base_model)
-                slices_in.append(
-                    InputSliceDefinition(
-                        layer_range=[0, base_info.num_layers()],
-                        model=base_model,
+                if base_model and not base_included:
+                    logging.info(
+                        "Base model specified but not in input models - adding"
+                    )
+                    base_cfg = base_model.config(
+                        trust_remote_code=self.options.trust_remote_code
+                    )
+                    num_layers = module_arch.num_layers(base_cfg)
+                    slices_in.append(
+                        InputSliceDefinition(
+                            layer_range=[0, num_layers],
+                            model=base_model,
+                        )
                     )
-                )
 
-            self.config.slices = [OutputSliceDefinition(sources=slices_in)]
-            self.config.models = None
+                module_out.slices = [OutputSliceDefinition(sources=slices_in)]
+                module_out.models = None
 
     def plan_tensor(
         self,
output slices" + self.config.modules = {} + for module_name in self.arch_info.modules: + self.config.modules[module_name] = OutputModuleDefinition( + name=module_name, models=self.config.models ) + self.config.models = None - slices_in = [] - base_included = False - - for model_in in self.config.models: - if base_model and model_in.model == base_model: - base_included = True - - model_info = self.model_arch_info(model_in.model) - slices_in.append( - InputSliceDefinition( - layer_range=[0, model_info.num_layers()], - model=model_in.model, - parameters=model_in.parameters, - ) + # slices -> modules.slices + if self.config.slices: + if len(self.arch_info.modules) != 1: + raise RuntimeError( + "Model has multiple modules, must use modules: config syntax" ) + module_name = list(self.arch_info.modules.keys())[0] + self.config.modules = { + module_name: OutputModuleDefinition(slices=self.config.slices) + } + self.config.slices = None + + # modules.models -> modules.slices + for module_name in self.config.modules: + module_out = self.config.modules[module_name] + module_arch = self.arch_info.modules[module_name].architecture + + if module_out.models: + slices_in = [] + base_included = False + + for model_in in module_out.models: + if base_model and model_in.model == base_model: + base_included = True + + model_cfg = model_in.model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = module_arch.num_layers(model_cfg) + slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=model_in.model, + parameters=model_in.parameters, + ) + ) - if base_model and not base_included: - logging.info("Base model specified but not in input models - adding") - base_info = self.model_arch_info(base_model) - slices_in.append( - InputSliceDefinition( - layer_range=[0, base_info.num_layers()], - model=base_model, + if base_model and not base_included: + logging.info( + "Base model specified but not in input models - adding" + ) + base_cfg = base_model.config( + trust_remote_code=self.options.trust_remote_code + ) + num_layers = module_arch.num_layers(base_cfg) + slices_in.append( + InputSliceDefinition( + layer_range=[0, num_layers], + model=base_model, + ) ) - ) - self.config.slices = [OutputSliceDefinition(sources=slices_in)] - self.config.models = None + module_out.slices = [OutputSliceDefinition(sources=slices_in)] + module_out.models = None def plan_tensor( self, @@ -188,6 +220,7 @@ def plan_tensor( base_model=base_model, ) + print(f"output_weight: {repr(weight)} ({type(weight)})") tensor_task = tensor_merge_method.make_task( output_weight=weight, tensors=tensor_input_task, @@ -207,13 +240,15 @@ def plan_layer( layer_offset: int, t: float, cfg_reader: ConfigReader, + module_name: str, ): - weights_out: List[WeightInfo] = self.arch_info.layer_weights( - index=self._current_layers, + module_arch_def = self.arch_info.modules[module_name] + weights_out: List[WeightInfo] = module_arch_def.architecture.layer_weights( + index=self._current_module_layers, config=self.out_model_config, ) weights_in: List[List[WeightInfo]] = [ - self.model_arch_info(s.model).layer_weights( + self._model_module_arch(s.model, module_name).layer_weights( index=s.layer_range[0] + layer_offset ) for s in sources @@ -227,9 +262,14 @@ def plan_layer( cfg_reader=cfg_reader.with_t(t), ) - self._current_layers += 1 + self._current_module_layers += 1 - def plan_slice(self, definition: OutputSliceDefinition): + def plan_slice( + self, + definition: OutputSliceDefinition, + module_def: 
@@ -239,7 +279,9 @@ def plan_slice(self, definition: OutputSliceDefinition):
         )
         num_layers = slice_lengths[0]
 
-        cfg_reader = ConfigReader(config=self.config, slice_out=definition, t=0)
+        cfg_reader = ConfigReader(
+            config=self.config, slice_out=definition, t=0, module=module_def
+        )
         for idx in range(num_layers):
             # compute t for interpolated gradients
             if num_layers > 1:
@@ -252,6 +294,44 @@ def plan_slice(self, definition: OutputSliceDefinition):
                 layer_offset=idx,
                 t=t,
                 cfg_reader=cfg_reader,
+                module_name=module_name,
             )
+
+    def plan_module(self, module_name: str, definition: OutputModuleDefinition):
+        self._current_module_layers = 0
+
+        module_arch_def = self.arch_info.modules[module_name]
+        config_reader = ConfigReader(config=self.config, t=0, module=definition)
+
+        for weight_info in module_arch_def.architecture.pre_weights(
+            self.out_model_config
+        ):
+            self.plan_tensor(
+                weight_info,
+                [weight_info] * len(definition.slices[0].sources),
+                [s.model for s in definition.slices[0].sources],
+                config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice(
+                    definition.slices[0]
+                ),
+            )
+
+        for out_slice in definition.slices:
+            self.plan_slice(
+                out_slice,
+                module_def=definition,
+                module_name=module_name,
+            )
+
+        for weight_info in module_arch_def.architecture.post_weights(
+            self.out_model_config
+        ):
+            self.plan_tensor(
+                weight_info,
+                [weight_info] * len(definition.slices[-1].sources),
+                [s.model for s in definition.slices[-1].sources],
+                config_reader.for_tensor(tensor_name=weight_info.name).for_out_slice(
+                    definition.slices[-1]
+                ),
+            )
 
     def plan_to_disk(self, out_path: str) -> List[Task]:
@@ -298,31 +378,7 @@ def plan_in_memory(self) -> List[ReturnTensor]:
 
     def _plan(self):
         self.normalize_config()
-        self._tensors = []
-
-        for weight_info in self.arch_info.pre_weights(config=self.out_model_config):
-            self.plan_tensor(
-                weight_info,
-                [weight_info] * len(self.config.slices[0].sources),
-                [s.model for s in self.config.slices[0].sources],
-                ConfigReader(
-                    config=self.config,
-                    t=0,
-                    tensor_name=weight_info.name,
-                ).for_out_slice(self.config.slices[0]),
-            )
+        self._tensors = []
 
-        for out_slice in self.config.slices:
-            self.plan_slice(out_slice)
-
-        for weight_info in self.arch_info.post_weights(config=self.out_model_config):
-            self.plan_tensor(
-                weight_info,
-                [weight_info] * len(self.config.slices[-1].sources),
-                [s.model for s in self.config.slices[-1].sources],
-                ConfigReader(
-                    config=self.config,
-                    t=1,
-                    tensor_name=weight_info.name,
-                ).for_out_slice(self.config.slices[-1]),
-            )
+        for module_name in self.config.modules:
+            self.plan_module(module_name, self.config.modules[module_name])
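After `normalize_config` runs, every accepted input shape has been funneled into the same canonical form. A sketch of the invariant `plan_module` relies on (single-module decoder assumed; model names and layer counts are hypothetical):

```python
# Hypothetical post-normalization state for a merge of two 32-layer models;
# plan_module then walks pre_weights, each slice's layers, and post_weights.
from mergekit.config import (
    InputSliceDefinition,
    OutputModuleDefinition,
    OutputSliceDefinition,
)

canonical = {
    "decoder": OutputModuleDefinition(
        slices=[
            OutputSliceDefinition(
                sources=[
                    InputSliceDefinition(model="model-a", layer_range=[0, 32]),
                    InputSliceDefinition(model="model-b", layer_range=[0, 32]),
                ]
            )
        ]
    )
}
# config.models and config.slices are both None at this point.
```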
diff --git a/mergekit/scripts/extract_lora.py b/mergekit/scripts/extract_lora.py
index ff063232..69c010bb 100644
--- a/mergekit/scripts/extract_lora.py
+++ b/mergekit/scripts/extract_lora.py
@@ -8,10 +8,8 @@
 import torch
 from peft.tuners.lora import QuantLinear
 from safetensors.torch import save_file
-from torch.nn.functional import pad
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM
-from transformers.modeling_utils import PreTrainedModel
 from transformers.pytorch_utils import Conv1D
 
 from mergekit.card import generate_card_lora
diff --git a/tests/test_chat_template.py b/tests/test_chat_template.py
index af511a2b..2bd41cde 100644
--- a/tests/test_chat_template.py
+++ b/tests/test_chat_template.py
@@ -1,13 +1,25 @@
 from typing import Optional
 
-from common import run_and_check_merge
-from test_basic_merges import model_b
-from test_tokenizer import model_base
+import pytest
+from common import make_picollama, run_and_check_merge
+from test_tokenizer import make_tokenizer
 from transformers import AutoTokenizer
 
 from mergekit.config import InputModelDefinition, MergeConfiguration
 
 
+@pytest.fixture(scope="session")
+def model_base(tmp_path_factory):
+    model_path = make_picollama(tmp_path_factory.mktemp("model_base"), vocab_size=64)
+    make_tokenizer(vocab_size=64, added_tokens=[]).save_pretrained(model_path)
+    return model_path
+
+
+@pytest.fixture(scope="session")
+def model_b(tmp_path_factory):
+    return make_picollama(tmp_path_factory.mktemp("model_b"))
+
+
 def check_chat_template(model_path: str, needle: Optional[str] = None):
     tokenizer = AutoTokenizer.from_pretrained(model_path)
     if needle is None:
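End to end, the module plumbing stays invisible for single-module models. A sketch of driving a merge with the new `modules` form through the existing Python API (output path and model names hypothetical; `run_merge` and `MergeOptions` per mergekit's existing entry points):

```python
import yaml

from mergekit.config import MergeConfiguration
from mergekit.merge import run_merge
from mergekit.options import MergeOptions

CONFIG_YML = """
merge_method: linear
modules:
  decoder:
    models:
      - model: model-a
        parameters:
          weight: 0.3
      - model: model-b
        parameters:
          weight: 0.7
dtype: bfloat16
"""

config = MergeConfiguration.model_validate(yaml.safe_load(CONFIG_YML))
run_merge(config, out_path="./merged", options=MergeOptions())
```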