Disable sequential_targets from modifiers #1559

Draft · wants to merge 4 commits into base: main

25 changes: 17 additions & 8 deletions src/llmcompressor/modifiers/awq/base.py
@@ -1,5 +1,5 @@
import inspect
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from compressed_tensors.quantization import (
@@ -12,7 +12,7 @@
update_offload_parameter,
)
from loguru import logger
from pydantic import ConfigDict, PrivateAttr, model_validator
from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator
from torch.nn import Module
from tqdm import tqdm

@@ -27,6 +27,7 @@
from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.sentinel import Sentinel
from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
from llmcompressor.utils.helpers import calibration_forward_context
from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
@@ -96,8 +97,6 @@ class AWQModifier(Modifier, QuantizationMixin):
- on_finalize
- clear resolved mappings and captured activations

:param sequential_targets: list of module names to compress in
the same calibration pass
:param mappings: list activation layers to smooth, and which layers to
scale the output such that activations are smoothed.
Each entry of the mapping list should be a list itself, in which the first
@@ -116,11 +115,7 @@ class AWQModifier(Modifier, QuantizationMixin):
and weights to determine the scaling factor
"""

# Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

# User-provided vars (in addition to QuantizationMixin args)
sequential_targets: Union[str, List[str], None] = None
mappings: Optional[List[AWQMapping]] = None
offload_device: Optional[torch.device] = None
duo_scaling: bool = True
@@ -141,6 +136,20 @@
default_factory=dict
)

# deprecated
sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")

# Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

@field_validator("sequential_targets", mode="before")
def validate_sequential_targets(cls, value: bool) -> bool:
if value is not Sentinel("deprecated"):
raise ValueError(
"Setting `sequential_targets` via modifiers is no longer supported, "
"Please use `oneshot(sequential_targets=...)`"
)

@model_validator(mode="after")
def validate_model_after(model: "AWQModifier") -> "AWQModifier":
"""
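
The new AWQ validator above replaces the typed `sequential_targets` field with a `Sentinel` default plus a `field_validator` that rejects any user-supplied value. A minimal sketch of that pattern, using a stand-in PEP 661-style `Sentinel` and a generic model rather than the real llmcompressor classes; note that pydantic does not validate defaults, so the check only fires when a recipe actually sets the field:

```python
# Sketch only: Sentinel is a stand-in for llmcompressor.sentinel.Sentinel,
# assumed to return the same object for the same name (PEP 661 style).
from typing import Any, Union

from pydantic import BaseModel, ConfigDict, field_validator


class Sentinel:
    _registry: dict = {}

    def __new__(cls, name: str):
        if name not in cls._registry:
            instance = super().__new__(cls)
            instance.name = name
            cls._registry[name] = instance
        return cls._registry[name]


class ExampleModifier(BaseModel):
    # Allow arbitrary types because Sentinel is not a pydantic-native type
    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)

    # deprecated: kept only so old recipes fail loudly instead of being ignored
    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")

    @field_validator("sequential_targets", mode="before")
    def validate_sequential_targets(cls, value: Any) -> Any:
        # Defaults are not validated, so this only runs for user-supplied values
        if value is not Sentinel("deprecated"):
            raise ValueError(
                "Setting `sequential_targets` via modifiers is no longer supported, "
                "Please use `oneshot(sequential_targets=...)`"
            )
        return value  # hand the value back so the field keeps its sentinel


ExampleModifier()  # ok: the field stays at Sentinel("deprecated")
# ExampleModifier(sequential_targets=["LlamaDecoderLayer"])  # raises a ValidationError
```
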
10 changes: 4 additions & 6 deletions src/llmcompressor/modifiers/obcq/base.py
@@ -45,6 +45,10 @@ class SparseGPTModifier(SparsityModifierBase):
- on_finalize
- remove_hooks()

:param targets: list of module names to quantize if a scheme is provided. Defaults
to Linear layers
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL), more information can be found
@@ -62,12 +66,6 @@
previously pruned model, defaults to False.
:param offload_hessians: Set to True for decreased memory usage but increased
runtime.
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# modifier arguments
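
Since `sequential_targets` is no longer accepted on the modifier, a migrated recipe for this class carries only `targets`/`ignore` and passes `sequential_targets` to the `oneshot` entrypoint, as the new validator message suggests. A hedged usage sketch; the model id, dataset name, and the exact `oneshot` keyword support are assumptions, not taken from this diff:

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.obcq import SparseGPTModifier

recipe = SparseGPTModifier(
    sparsity=0.5,
    targets=["Linear"],   # modules to sparsify
    ignore=["lm_head"],   # skip the output head
)

oneshot(
    model="meta-llama/Llama-3.2-1B",            # placeholder model id
    dataset="open_platypus",                    # placeholder calibration dataset
    recipe=recipe,
    sequential_targets=["LlamaDecoderLayer"],   # now passed here, not on the modifier
)
```
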
40 changes: 19 additions & 21 deletions src/llmcompressor/modifiers/obcq/sgpt_base.py
@@ -8,13 +8,14 @@
import torch
from loguru import logger
from pydantic import Field, PrivateAttr, field_validator, model_validator
from transformers import PreTrainedModel

from llmcompressor.core import Event, EventType, State
from llmcompressor.modifiers.modifier import Modifier
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.sentinel import Sentinel
from llmcompressor.utils.pytorch.module import (
get_layers,
get_no_split_params,
get_prunable_layers,
match_targets,
)
@@ -27,35 +28,41 @@ class SparsityModifierBase(Modifier):
"""

# modifier arguments
targets: Union[str, List[str]] = ["Linear"]
ignore: List[str] = Field(default_factory=list)
sparsity: Optional[Union[float, List[float]]]
sparsity_profile: Optional[str] = None
mask_structure: str = "0:0"
owl_m: Optional[int] = None
owl_lmbda: Optional[float] = None

# data pipeline arguments
sequential_update: Optional[bool] = False # deprecated
sequential_targets: Union[str, List[str], None] = None
targets: Union[str, List[str]] = ["Linear"]
ignore: List[str] = Field(default_factory=list)

# private variables
_prune_n: Optional[int] = PrivateAttr(default=None)
_prune_m: Optional[int] = PrivateAttr(default=None)
_module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
_target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
_module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

# deprecated
sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
if value is not Sentinel("deprecated"):
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True
@field_validator("sequential_targets", mode="before")
def validate_sequential_targets(cls, value: bool) -> bool:
if value is not Sentinel("deprecated"):
raise ValueError(
"Setting `sequential_targets` via modifiers is no longer supported, "
"Please use `oneshot(sequential_targets=...)`"
)

@field_validator("sparsity_profile", mode="before")
def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
@@ -109,12 +116,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

:param state: session state storing input model and calibration data
"""
model: torch.nn.Module = state.model
model: PreTrainedModel = state.model
dataloader: torch.utils.data.DataLoader = state.data.calib

# infer module and sequential targets
self.sequential_targets = self._infer_sequential_targets(model)
layers = get_layers(self.sequential_targets, model)
sequential_targets = model._get_no_split_modules("auto")
layers = get_layers(sequential_targets, model)
self._target_layers = get_layers(
self.targets, model
) # layers containing targets
@@ -191,15 +198,6 @@ def on_end(self, state: State, event: Event, **kwargs):
self.ended_ = True
self.remove_hooks()

def _infer_sequential_targets(
self, model: torch.nn.Module
) -> Union[str, List[str]]:
if self.sequential_targets is None:
return get_no_split_params(model)
if isinstance(self.sequential_targets, str):
return [self.sequential_targets]
return self.sequential_targets

def _infer_owl_layer_sparsity(
self,
model: torch.nn.Module,
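
The removed `_infer_sequential_targets` helper (and the `get_no_split_params` import) is replaced by a direct call to transformers' private `PreTrainedModel._get_no_split_modules`, which collects the `_no_split_modules` class names declared by the model and its submodules. A small illustration, assuming that private API keeps its current behavior (underscore-prefixed APIs can change between transformers releases):

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # placeholder id
sequential_targets = model._get_no_split_modules("auto")  # same call as in on_initialize above
print(sequential_targets)  # e.g. ["LlamaDecoderLayer"]
```
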
13 changes: 4 additions & 9 deletions src/llmcompressor/modifiers/pruning/wanda/base.py
@@ -36,14 +36,15 @@ class WandaPruningModifier(SparsityModifierBase):
Lifecycle:
- on_initialize
- register_hook(module, calibrate_module, "forward")
- run_sequential / run_layer_sequential / run_basic
- make_empty_row_scalars
- accumulate_row_scalars
- on_sequential_batch_end
- sparsify_weight
- on_finalize
- remove_hooks()

:param targets: list of module names to quantize if a scheme is provided. Defaults
to Linear layers
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
:param sparsity: Sparsity to compress model to
:param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
Layerwise Sparsity (OWL), more information can be found
@@ -53,12 +54,6 @@ class WandaPruningModifier(SparsityModifierBase):
shape. Defaults to 0:0 which represents an unstructured mask.
:param owl_m: Number of outliers to use for OWL
:param owl_lmbda: Lambda value to use for OWL
:param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `targets`
:param targets: list of layer names to compress during OBCQ, or '__ALL__'
to compress every layer in the model. Alias for `sequential_targets`
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target. Defaults to empty list.
"""

# private variables
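
The same `targets`/`ignore` documentation now applies to WANDA. A brief hedged sketch mirroring the SparseGPT example above, this time with a 2:4 semi-structured mask instead of the default unstructured "0:0" (import path and values are illustrative, not taken from this diff):

```python
from llmcompressor import oneshot
from llmcompressor.modifiers.pruning import WandaPruningModifier

recipe = WandaPruningModifier(
    sparsity=0.5,
    mask_structure="2:4",  # N:M pattern; "0:0" means unstructured
    targets=["Linear"],
    ignore=["lm_head"],
)

oneshot(model="meta-llama/Llama-3.2-1B", dataset="open_platypus", recipe=recipe)
```
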
24 changes: 15 additions & 9 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,6 +1,6 @@
import contextlib
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import torch
from compressed_tensors.quantization import (
@@ -61,7 +61,7 @@ class GPTQModifier(Modifier, QuantizationMixin):

Lifecycle:
- on_initialize
- apply config to model
- apply quantization config to model
- on_start
- add activation calibration hooks
- add gptq weight calibration hooks
@@ -71,8 +71,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
- remove_hooks()
- model.apply(freeze_module_quantization)

:param sequential_targets: list of layer names to compress during GPTQ, or
'__ALL__' to compress every layer in the model
:param block_size: Used to determine number of columns to compress in one pass
:param dampening_frac: Amount of dampening to apply to H, as a fraction of the
diagonal norm
@@ -83,7 +81,7 @@ class GPTQModifier(Modifier, QuantizationMixin):

:param config_groups: dictionary specifying quantization schemes to apply to target
modules. Modules not matching a scheme target will NOT be quantized.
:param targets: list of layer names to quantize if a scheme is provided. Defaults
:param targets: list of module names to quantize if a scheme is provided. Defaults
to Linear layers
:param ignore: optional list of module class names or submodule names to not
quantize even if they match a target in config_groups. Defaults to empty list.
@@ -106,8 +104,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
"""

# gptq modifier arguments
sequential_update: bool = True # DEPRECATED
sequential_targets: Union[str, List[str], None] = None
block_size: int = 128
dampening_frac: Optional[float] = 0.01
actorder: Optional[Union[ActivationOrdering, Sentinel]] = None
@@ -118,16 +114,26 @@ class GPTQModifier(Modifier, QuantizationMixin):
_hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
_num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)

# deprecated
sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")

@field_validator("sequential_update", mode="before")
def validate_sequential_update(cls, value: bool) -> bool:
if not value:
if value is not Sentinel("deprecated"):
warnings.warn(
"`sequential_update=False` is no longer supported, setting "
"sequential_update=True",
DeprecationWarning,
)

return True
@field_validator("sequential_targets", mode="before")
def validate_sequential_targets(cls, value: bool) -> bool:
if value is not Sentinel("deprecated"):
raise ValueError(
"Setting `sequential_targets` via modifiers is no longer supported, "
"Please use `oneshot(sequential_targets=...)`"
)

def resolve_quantization_config(self) -> QuantizationConfig:
config = super().resolve_quantization_config()
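
The net effect of the two GPTQ validators above: an explicit `sequential_update` still constructs the modifier and only warns, while `sequential_targets` is rejected at construction time. A behavior sketch assuming the usual import path and a W4A16 scheme; the warning and error text are quoted from this diff:

```python
from pydantic import ValidationError

from llmcompressor.modifiers.quantization import GPTQModifier

# Deprecated but tolerated: any explicit sequential_update value only warns.
GPTQModifier(scheme="W4A16", targets=["Linear"], sequential_update=False)
# -> DeprecationWarning: `sequential_update=False` is no longer supported, ...

# Rejected outright: sequential_targets now belongs to oneshot(...).
try:
    GPTQModifier(
        scheme="W4A16",
        targets=["Linear"],
        sequential_targets=["LlamaDecoderLayer"],
    )
except ValidationError as err:
    print(err)  # wraps "Setting `sequential_targets` via modifiers is no longer supported, ..."
```
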
2 changes: 1 addition & 1 deletion src/llmcompressor/pipelines/layer_sequential/helpers.py
@@ -69,7 +69,7 @@ def capture_first_layer_intermediates(
desc = "Preparing intermediates cache"
for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
batch = tensors_to_device(batch, model_device)
batch = tensors_to_device(batch, torch.device("cpu"))

try:
model(**batch)
12 changes: 7 additions & 5 deletions src/llmcompressor/pipelines/layer_sequential/pipeline.py
@@ -5,7 +5,7 @@
from compressed_tensors.utils import disable_offloading, get_execution_device
from torch.utils.data.dataloader import DataLoader

from llmcompressor.core import LifecycleCallbacks, active_session
from llmcompressor.core import LifecycleCallbacks
from llmcompressor.modifiers.utils.hooks import HooksMixin
from llmcompressor.pipelines.cache import IntermediatesCache
from llmcompressor.pipelines.layer_sequential.helpers import (
@@ -17,7 +17,7 @@
from llmcompressor.pipelines.registry import CalibrationPipeline
from llmcompressor.pipelines.sequential.helpers import (
dispatch_for_sequential,
get_sequential_targets,
infer_sequential_targets,
)
from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -56,15 +56,17 @@ def __call__(
:param dataloader: loads data for calibration
:param dataset_args: dataset arguments relevant to pipelines
"""
session = active_session()
# prepare model for sequential onloading
dispatch_for_sequential(model)

# prepare model for sequential onloading
dispatch_for_sequential(model)
model_device = get_execution_device(model)

# find layers
modifiers = session.get_modifiers()
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
sequential_targets = infer_sequential_targets(
model, dataset_args.sequential_targets
)
layers = match_modules(model, sequential_targets)

LifecycleCallbacks.calibration_epoch_start()
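
`infer_sequential_targets` replaces `get_sequential_targets`, so the pipeline no longer consults the active session's modifiers for targets; it only looks at `dataset_args.sequential_targets` and the model. The helper's body is not shown in this section, so the following is a hypothetical sketch of what it presumably does, mirroring the fallback used in `sgpt_base.py` above:

```python
from typing import List, Optional, Union

from transformers import PreTrainedModel


def infer_sequential_targets(
    model: PreTrainedModel,
    sequential_targets: Optional[Union[str, List[str]]],
) -> List[str]:
    """Hypothetical: prefer user-supplied targets, otherwise fall back to the
    layer classes transformers marks as unsplittable."""
    if sequential_targets is None:
        return model._get_no_split_modules("auto")
    if isinstance(sequential_targets, str):
        return [sequential_targets]
    return sequential_targets
```
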