diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 51e8cf8b9..7dc442854 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,5 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization import (
@@ -12,7 +12,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import ConfigDict, PrivateAttr, model_validator
+from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -27,6 +27,7 @@
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
+from llmcompressor.sentinel import Sentinel
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
@@ -96,8 +97,6 @@ class AWQModifier(Modifier, QuantizationMixin):
     - on_finalize
         - clear resolved mappings and captured activations
 
-    :param sequential_targets: list of module names to compress in
-        the same calibration pass
     :param mappings: list activation layers to smooth, and which layers to
         scale the output such that activations are smoothed. Each entry of the
         mapping list should be a list itself, in which the first
@@ -116,11 +115,7 @@ class AWQModifier(Modifier, QuantizationMixin):
         and weights to determine the scaling factor
     """
 
-    # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
-    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
 
-    # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
     mappings: Optional[List[AWQMapping]] = None
     offload_device: Optional[torch.device] = None
     duo_scaling: bool = True
@@ -141,6 +136,20 @@ class AWQModifier(Modifier, QuantizationMixin):
         default_factory=dict
     )
 
+    # deprecated
+    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")
+
+    # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
+    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
+
+    @field_validator("sequential_targets", mode="before")
+    def validate_sequential_targets(cls, value: bool) -> bool:
+        if value is not Sentinel("deprecated"):
+            raise ValueError(
+                "Setting `sequential_targets` via modifiers is no longer supported. "
+                "Please use `oneshot(sequential_targets=...)`"
+            )
+
     @model_validator(mode="after")
     def validate_model_after(model: "AWQModifier") -> "AWQModifier":
         """
diff --git a/src/llmcompressor/modifiers/obcq/base.py b/src/llmcompressor/modifiers/obcq/base.py
index 0dc69ae03..9595ba9c5 100644
--- a/src/llmcompressor/modifiers/obcq/base.py
+++ b/src/llmcompressor/modifiers/obcq/base.py
@@ -45,6 +45,10 @@ class SparseGPTModifier(SparsityModifierBase):
     - on_finalize
         - remove_hooks()
 
+    :param targets: list of module names to sparsify. Defaults
+        to Linear layers
+    :param ignore: optional list of module class names or submodule names to not
+        sparsify even if they match a target. Defaults to empty list.
     :param sparsity: Sparsity to compress model to
     :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
        Layerwise Sparsity (OWL), more information can be found
@@ -62,12 +66,6 @@ class SparseGPTModifier(SparsityModifierBase):
        previously pruned model, defaults to False.
    :param offload_hessians: Set to True for decreased memory usage but increased
        runtime.
-    :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `targets`
-    :param targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `sequential_targets`
-    :param ignore: optional list of module class names or submodule names to not
-        quantize even if they match a target. Defaults to empty list.
     """
 
     # modifier arguments
diff --git a/src/llmcompressor/modifiers/obcq/sgpt_base.py b/src/llmcompressor/modifiers/obcq/sgpt_base.py
index ce41273f3..b36a16d1b 100644
--- a/src/llmcompressor/modifiers/obcq/sgpt_base.py
+++ b/src/llmcompressor/modifiers/obcq/sgpt_base.py
@@ -8,13 +8,14 @@
 import torch
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator
+from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.sentinel import Sentinel
 from llmcompressor.utils.pytorch.module import (
     get_layers,
-    get_no_split_params,
     get_prunable_layers,
     match_targets,
 )
@@ -27,18 +28,14 @@ class SparsityModifierBase(Modifier):
     """
 
     # modifier arguments
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
     mask_structure: str = "0:0"
     owl_m: Optional[int] = None
     owl_lmbda: Optional[float] = None
 
-    # data pipeline arguments
-    sequential_update: Optional[bool] = False  # deprecated
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str]] = ["Linear"]
-    ignore: List[str] = Field(default_factory=list)
-
     # private variables
     _prune_n: Optional[int] = PrivateAttr(default=None)
     _prune_m: Optional[int] = PrivateAttr(default=None)
@@ -46,16 +43,26 @@ class SparsityModifierBase(Modifier):
     _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
     _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
 
+    # deprecated
+    sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
+    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")
+
     @field_validator("sequential_update", mode="before")
     def validate_sequential_update(cls, value: bool) -> bool:
-        if not value:
+        if value is not Sentinel("deprecated"):
             warnings.warn(
                 "`sequential_update=False` is no longer supported, setting "
                 "sequential_update=True",
                 DeprecationWarning,
             )
-        return True
 
+    @field_validator("sequential_targets", mode="before")
+    def validate_sequential_targets(cls, value: bool) -> bool:
+        if value is not Sentinel("deprecated"):
+            raise ValueError(
+                "Setting `sequential_targets` via modifiers is no longer supported. "
+                "Please use `oneshot(sequential_targets=...)`"
+            )
 
     @field_validator("sparsity_profile", mode="before")
     def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
@@ -109,12 +116,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool:
 
         :param state: session state storing input model and calibration data
""" - model: torch.nn.Module = state.model + model: PreTrainedModel = state.model dataloader: torch.utils.data.DataLoader = state.data.calib # infer module and sequential targets - self.sequential_targets = self._infer_sequential_targets(model) - layers = get_layers(self.sequential_targets, model) + sequential_targets = model._get_no_split_modules("auto") + layers = get_layers(sequential_targets, model) self._target_layers = get_layers( self.targets, model ) # layers containing targets @@ -191,15 +198,6 @@ def on_end(self, state: State, event: Event, **kwargs): self.ended_ = True self.remove_hooks() - def _infer_sequential_targets( - self, model: torch.nn.Module - ) -> Union[str, List[str]]: - if self.sequential_targets is None: - return get_no_split_params(model) - if isinstance(self.sequential_targets, str): - return [self.sequential_targets] - return self.sequential_targets - def _infer_owl_layer_sparsity( self, model: torch.nn.Module, diff --git a/src/llmcompressor/modifiers/pruning/wanda/base.py b/src/llmcompressor/modifiers/pruning/wanda/base.py index de562f6d7..5d6673411 100644 --- a/src/llmcompressor/modifiers/pruning/wanda/base.py +++ b/src/llmcompressor/modifiers/pruning/wanda/base.py @@ -36,14 +36,15 @@ class WandaPruningModifier(SparsityModifierBase): Lifecycle: - on_initialize - register_hook(module, calibrate_module, "forward") - - run_sequential / run_layer_sequential / run_basic - - make_empty_row_scalars - - accumulate_row_scalars - on_sequential_batch_end - sparsify_weight - on_finalize - remove_hooks() + :param targets: list of module names to quantize if a scheme is provided. Defaults + to Linear layers + :param ignore: optional list of module class names or submodule names to not + quantize even if they match a target. Defaults to empty list. :param sparsity: Sparsity to compress model to :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed Layerwise Sparsity (OWL), more information can be found @@ -53,12 +54,6 @@ class WandaPruningModifier(SparsityModifierBase): shape. Defaults to 0:0 which represents an unstructured mask. :param owl_m: Number of outliers to use for OWL :param owl_lmbda: Lambda value to use for OWL - :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__' - to compress every layer in the model. Alias for `targets` - :param targets: list of layer names to compress during OBCQ, or '__ALL__' - to compress every layer in the model. Alias for `sequential_targets` - :param ignore: optional list of module class names or submodule names to not - quantize even if they match a target. Defaults to empty list. 
""" # private variables diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py index 7ae61f3e2..147f9b514 100644 --- a/src/llmcompressor/modifiers/quantization/gptq/base.py +++ b/src/llmcompressor/modifiers/quantization/gptq/base.py @@ -1,6 +1,6 @@ import contextlib import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import torch from compressed_tensors.quantization import ( @@ -61,7 +61,7 @@ class GPTQModifier(Modifier, QuantizationMixin): Lifecycle: - on_initialize - - apply config to model + - apply quantization config to model - on_start - add activation calibration hooks - add gptq weight calibration hooks @@ -71,8 +71,6 @@ class GPTQModifier(Modifier, QuantizationMixin): - remove_hooks() - model.apply(freeze_module_quantization) - :param sequential_targets: list of layer names to compress during GPTQ, or - '__ALL__' to compress every layer in the model :param block_size: Used to determine number of columns to compress in one pass :param dampening_frac: Amount of dampening to apply to H, as a fraction of the diagonal norm @@ -83,7 +81,7 @@ class GPTQModifier(Modifier, QuantizationMixin): :param config_groups: dictionary specifying quantization schemes to apply to target modules. Modules not matching a scheme target will NOT be quantized. - :param targets: list of layer names to quantize if a scheme is provided. Defaults + :param targets: list of module names to quantize if a scheme is provided. Defaults to Linear layers :param ignore: optional list of module class names or submodule names to not quantize even if they match a target in config_groups. Defaults to empty list. @@ -106,8 +104,6 @@ class GPTQModifier(Modifier, QuantizationMixin): """ # gptq modifier arguments - sequential_update: bool = True # DEPRECATED - sequential_targets: Union[str, List[str], None] = None block_size: int = 128 dampening_frac: Optional[float] = 0.01 actorder: Optional[Union[ActivationOrdering, Sentinel]] = None @@ -118,16 +114,26 @@ class GPTQModifier(Modifier, QuantizationMixin): _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict) _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict) + # deprecated + sequential_update: Union[Sentinel, Any] = Sentinel("deprecated") + sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated") + @field_validator("sequential_update", mode="before") def validate_sequential_update(cls, value: bool) -> bool: - if not value: + if value is not Sentinel("deprecated"): warnings.warn( "`sequential_update=False` is no longer supported, setting " "sequential_update=True", DeprecationWarning, ) - return True + @field_validator("sequential_targets", mode="before") + def validate_sequential_targets(cls, value: bool) -> bool: + if value is not Sentinel("deprecated"): + raise ValueError( + "Setting `sequential_targets` via modifiers is no longer supported, " + "Please use `oneshot(sequential_targets=...)`" + ) def resolve_quantization_config(self) -> QuantizationConfig: config = super().resolve_quantization_config() diff --git a/src/llmcompressor/pipelines/layer_sequential/helpers.py b/src/llmcompressor/pipelines/layer_sequential/helpers.py index c8b9fcbd3..c7e177411 100644 --- a/src/llmcompressor/pipelines/layer_sequential/helpers.py +++ b/src/llmcompressor/pipelines/layer_sequential/helpers.py @@ -69,7 +69,7 @@ def capture_first_layer_intermediates( desc = "Preparing intermediates cache" 
     for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
         batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-        batch = tensors_to_device(batch, model_device)
+        batch = tensors_to_device(batch, torch.device("cpu"))
 
         try:
             model(**batch)
diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
index e5a608708..650d5a66a 100644
--- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -5,7 +5,7 @@
 from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader
 
-from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.core import LifecycleCallbacks
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.layer_sequential.helpers import (
@@ -17,7 +17,7 @@
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
-    get_sequential_targets,
+    infer_sequential_targets,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
@@ -56,15 +56,15 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        session = active_session()
 
         # prepare model for sequential onloading
         dispatch_for_sequential(model)
         model_device = get_execution_device(model)
 
         # find layers
-        modifiers = session.get_modifiers()
-        sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
+        sequential_targets = infer_sequential_targets(
+            model, dataset_args.sequential_targets
+        )
         layers = match_modules(model, sequential_targets)
 
         LifecycleCallbacks.calibration_epoch_start()
diff --git a/src/llmcompressor/pipelines/sequential/helpers.py b/src/llmcompressor/pipelines/sequential/helpers.py
index 4f562818a..b374ceb24 100644
--- a/src/llmcompressor/pipelines/sequential/helpers.py
+++ b/src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,7 +2,7 @@
 import inspect
 from collections import deque
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 import torch
 from compressed_tensors.quantization import find_name_or_class_matches
@@ -20,20 +20,15 @@
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils.fx import HFTracer
 
-from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
-from llmcompressor.utils.pytorch.module import get_no_split_params
 
 from .ast_helpers import autowrap_forwards
 
-if TYPE_CHECKING:
-    from llmcompressor.args.dataset_arguments import DatasetArguments
-
 __all__ = [
     "trace_subgraphs",
     "Subgraph",
-    "get_sequential_targets",
+    "infer_sequential_targets",
     "dispatch_for_sequential",
 ]
@@ -428,59 +423,21 @@ def match_modules(model: Module, target_names: List[str]) -> Set[Module]:
     )
 
 
-def get_sequential_targets(
-    modifiers: List[Modifier], model: PreTrainedModel, args: "DatasetArguments"
+def infer_sequential_targets(
+    model: PreTrainedModel, sequential_targets: Optional[List[str]]
 ) -> List[str]:
     """
-    Infer sequential targets from modifiers list and dataset args
+    Infer sequential targets from dataset args
 
     :param model: model being calibrated
-    :param modifiers: list of modifiers being applied during calibration
-    :param dataset_args: dataset arguments passed by user
+    :param sequential_targets: optional targets passed by user
     :return: list of sequential targets
     """
-    modifier_targets = [
-        (modifier, modifier.sequential_targets)
-        for modifier in modifiers
-        if getattr(modifier, "sequential_targets", None) is not None
-    ]
-
-    # deprecation warning
-    if len(modifier_targets) >= 1:
-        logger.warning(
-            "Passing sequential targets through modifiers is deprecated, "
-            "please use `oneshot(sequential_targets=...)`"
-        )
+    if not sequential_targets:
+        sequential_targets = model._get_no_split_modules("auto")
 
-    # cannot infer from multiple modifiers
-    if len(modifier_targets) >= 2:
-        types = [type(modifier) for modifier, _ in modifier_targets]
-        raise ValueError(
-            "Cannot infer sequential targets from multiple sequential modifiers "
-            f"({types})"
-        )
-
-    # resolve single modifier
-    if len(modifier_targets) == 1:
-        if args.sequential_targets is not None:
-            raise ValueError(
-                f"Got sequential targets from both {type(modifier_targets[0][0])} "
-                "and dataset arguments `sequential_targets`"
-            )
-
-        sequential_targets = modifier_targets[0][1]
-
-    # if no modifiers, use data args
-    else:
-        sequential_targets = args.sequential_targets  # may be `None`
-
-    # validate and infer
-    if sequential_targets is None:
-        return get_no_split_params(model)
-    elif isinstance(sequential_targets, str):
-        return [sequential_targets]
-    else:
-        return sequential_targets
+    # fall back to targeting Linear layers if the model defines no no-split modules
+    return sequential_targets if len(sequential_targets) > 0 else ["Linear"]
 
 
 def add_line_numbers(text: str) -> str:
diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py
index 9a2b8f3c9..3495941a3 100644
--- a/src/llmcompressor/pipelines/sequential/pipeline.py
+++ b/src/llmcompressor/pipelines/sequential/pipeline.py
@@ -5,13 +5,13 @@
 from torch.utils.data.dataloader import DataLoader
 from tqdm import tqdm
 
-from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.core import LifecycleCallbacks
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
-    get_sequential_targets,
+    infer_sequential_targets,
     trace_subgraphs,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
@@ -50,15 +50,15 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        session = active_session()
 
         # prepare model for sequential onloading
         dispatch_for_sequential(model)
         model_device = get_execution_device(model)
 
         # prepare to trace subgraphs
-        modifiers = session.get_modifiers()
-        sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
+        sequential_targets = infer_sequential_targets(
+            model, dataset_args.sequential_targets
+        )
         ignore = dataset_args.tracing_ignore
 
         # trace subgraphs
diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py
index bb70bf712..bb1970c30 100644
--- a/src/llmcompressor/transformers/tracing/debug.py
+++ b/src/llmcompressor/transformers/tracing/debug.py
@@ -7,8 +7,9 @@
 import transformers
 from transformers import AutoProcessor, PreTrainedModel
 
-from llmcompressor.utils.pytorch.module import get_no_split_params
-from llmcompressor.pipelines.sequential.helpers import trace_subgraphs, Subgraph
+from llmcompressor.pipelines.sequential.helpers import (
+    infer_sequential_targets, trace_subgraphs, Subgraph
+)
 from llmcompressor.transformers import TextGenerationDataset
 from llmcompressor.args import DatasetArguments
@@ -86,10 +87,9 @@ def trace(
     print("Loaded sample data")
 
     # infer sequential targets
-    if sequential_targets is None:
-        sequential_targets = get_no_split_params(model)
-    if isinstance(sequential_targets, str):
-        sequential_targets = [sequential_targets]
+    sequential_targets = infer_sequential_targets(
+        model, sequential_targets
+    )
 
     # Attempt trace
     print(
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 835493fa3..dc19185a5 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -12,7 +12,6 @@
 from packaging import version
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
-from transformers import PreTrainedModel
 
 from llmcompressor.core import ModelParameterizedLayer
 from llmcompressor.utils.fsdp.context import (
@@ -60,7 +59,6 @@
     "qat_active",
     "get_layers_params",
     "get_matching_layer",
-    "get_no_split_params",
     "get_layer_by_name",
 ]
 
@@ -316,25 +314,6 @@ def get_matching_layer(
     return match
 
 
-def get_no_split_params(model: PreTrainedModel) -> Union[str, List[str]]:
-    """
-    Get list of module classes that shouldn't be split when sharding. For
-    Hugging Face Transformer models, this is the decoder layer type. For other
-    types of models, this just returns all module names.
-
-    :return: list of class names that shouldn't be split
-    """
-    # importing here to avoid circular import
-    from llmcompressor.utils.fsdp.helpers import maybe_get_wrapped
-
-    model = maybe_get_wrapped(model)
-    no_split_modules = model._get_no_split_modules("auto")
-    if len(no_split_modules) <= 0:
-        return ALL_TARGET
-
-    return no_split_modules
-
-
 # https://discuss.pytorch.org/t/how-to-access-to-a-layer-by-module-name/83797/8
 def get_layer_by_name(layer_name: str, module: Module) -> Module:
     """
diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py b/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py
deleted file mode 100644
index abe74da19..000000000
--- a/tests/llmcompressor/transformers/obcq/test_obcq_infer_targets.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import pytest
-from accelerate import init_empty_weights
-from transformers import AutoModelForCausalLM
-
-from llmcompressor.modifiers.obcq import SparseGPTModifier
-
-
-@pytest.mark.integration
-def test_infer_targets():
-    modifier = SparseGPTModifier(sparsity=0.0)
-    with init_empty_weights():
-        model = AutoModelForCausalLM.from_pretrained("nm-testing/llama2.c-stories15M")
-
-    inferred = modifier._infer_sequential_targets(model)
-    assert inferred == ["LlamaDecoderLayer"]
diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
index a1de2533d..dea55d927 100644
--- a/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
+++ b/tests/llmcompressor/transformers/obcq/test_obcq_owl.py
@@ -28,7 +28,7 @@ def test_infer_owl_layer_sparsity():
     )
     dataloader = format_calibration_data(dataset)
 
-    sequential_targets = modifier._infer_sequential_targets(model)
+    sequential_targets = model._get_no_split_modules("auto")
     layers = get_layers(sequential_targets, model)
     sparsities = modifier._infer_owl_layer_sparsity(model, layers, dataloader)
     assert sparsities.keys() == layers.keys()
diff --git a/tests/llmcompressor/transformers/tracing/test_models.py b/tests/llmcompressor/transformers/tracing/test_models.py
index fba0ffe56..5603994a9 100644
--- a/tests/llmcompressor/transformers/tracing/test_models.py
+++ b/tests/llmcompressor/transformers/tracing/test_models.py
@@ -14,9 +14,11 @@
     WhisperForConditionalGeneration,
 )
 
-from llmcompressor.pipelines.sequential.helpers import match_modules
+from llmcompressor.pipelines.sequential.helpers import (
+    infer_sequential_targets,
+    match_modules,
+)
 from llmcompressor.transformers.tracing.debug import trace
-from llmcompressor.utils.pytorch.module import get_no_split_params
 
 
 @pytest.mark.skipif(
@@ -143,11 +145,7 @@ def test_model_trace(model_id, model_class, targets, modality, backends):
 
 
 def get_target_modules(model, sequential_targets):
-    if sequential_targets is None:
-        sequential_targets = get_no_split_params(model)
-    if isinstance(sequential_targets, str):
-        sequential_targets = [sequential_targets]
-
+    sequential_targets = infer_sequential_targets(model, sequential_targets)
    return match_modules(model, sequential_targets)
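Usage note (not part of the patch): after this change, `sequential_targets` is rejected on the modifiers themselves and should be passed to `oneshot(...)`; when omitted, the pipelines fall back to `model._get_no_split_modules("auto")`. A minimal migration sketch follows, assuming the `oneshot` entrypoint accepts `sequential_targets` as the new error messages direct; the model id is reused from the tests above, while the dataset name and quantization scheme are illustrative placeholders.

    # migration sketch -- names marked below are assumptions, not part of this patch
    from llmcompressor import oneshot
    from llmcompressor.modifiers.quantization import GPTQModifier

    # Previously: GPTQModifier(..., sequential_targets=["LlamaDecoderLayer"])
    # That keyword now raises a ValueError; the recipe only describes what to quantize.
    recipe = GPTQModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"])

    oneshot(
        model="nm-testing/llama2.c-stories15M",  # model id taken from the test above
        dataset="open_platypus",                 # illustrative calibration dataset
        recipe=recipe,
        # optional: if omitted, infer_sequential_targets() falls back to
        # model._get_no_split_modules("auto"), e.g. ["LlamaDecoderLayer"] for Llama
        sequential_targets=["LlamaDecoderLayer"],
    )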