From 3cb6e8d4d03170f539f3e3cf8c00a20ee20f1d9f Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Thu, 14 Aug 2025 14:28:10 -0700 Subject: [PATCH 1/2] [executorch] Add TorchAO wrapper config to allow filter_fn for quantize_ Pull Request resolved: https://github.com/pytorch/executorch/pull/13386 Fixing tests for stack that got reverted: https://github.com/pytorch/executorch/pull/13264 Changes: Support filter function in quantize_ function when using torchao quantize. Update unittests accordingly Use ComposableQuantizer if there are multiple quantizers and is of type torchao, for legacy quantizers use them directly with prepare_pt2e. Source transform modifies model inplace, so deep copy first to avoid modifying user provided model. ghstack-source-id: 303126052 @exported-using-ghexport Differential Revision: [D80206543](https://our.internmc.facebook.com/intern/diff/D80206543/) --- .../recipes/xnnpack_recipe_provider.py | 40 +++++--- .../xnnpack/recipes/xnnpack_recipe_types.py | 21 +++-- .../test/recipes/test_xnnpack_recipes.py | 94 +++++++++++-------- export/__init__.py | 9 +- export/recipe.py | 23 ++++- export/stages.py | 52 ++++++++-- export/tests/TARGETS | 2 +- export/tests/test_export_session.py | 10 +- export/tests/test_export_stages.py | 68 ++++++++++++-- 9 files changed, 230 insertions(+), 89 deletions(-) diff --git a/backends/xnnpack/recipes/xnnpack_recipe_provider.py b/backends/xnnpack/recipes/xnnpack_recipe_provider.py index 8fba58c12c3..436eb2db158 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_provider.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_provider.py @@ -25,6 +25,7 @@ get_xnnpack_executorch_backend_config, ) from executorch.export import ( + AOQuantizationConfig, BackendRecipeProvider, ExportRecipe, LoweringRecipe, @@ -57,31 +58,37 @@ def create_recipe( if recipe_type == XNNPackRecipeType.FP32: return self._build_fp32_recipe(recipe_type) - elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL: + elif recipe_type == XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL: return self._build_quantized_recipe( recipe_type, is_per_channel=True, is_dynamic=True ) - elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_CHANNEL: + elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL: return self._build_quantized_recipe( recipe_type, is_per_channel=True, is_dynamic=False ) - elif recipe_type == XNNPackRecipeType.INT8_STATIC_PER_TENSOR: + elif recipe_type == XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR: return self._build_quantized_recipe( recipe_type, is_per_channel=False, is_dynamic=False ) - elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL: - return self._build_int8da_intx_weight_recipe( + elif ( + recipe_type + == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL + ): + return self._build_torchao_quantized_recipe( recipe_type=recipe_type, is_per_channel=True, weight_dtype=torch.int4, ) - elif recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR: + elif ( + recipe_type + == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR + ): group_size = kwargs.get("group_size", 32) - return self._build_int8da_intx_weight_recipe( + return self._build_torchao_quantized_recipe( recipe_type=recipe_type, is_per_channel=False, weight_dtype=torch.int4, @@ -132,7 +139,7 @@ def _build_quantized_recipe( executorch_backend_config=get_xnnpack_executorch_backend_config(), ) - def _build_int8da_intx_weight_recipe( + def _build_torchao_quantized_recipe( self, recipe_type: RecipeType, 
is_per_channel: bool = True, @@ -141,17 +148,21 @@ def _build_int8da_intx_weight_recipe( ) -> ExportRecipe: if is_per_channel: weight_granularity = PerAxis(axis=0) + assert weight_dtype == torch.int4 or weight_dtype == torch.int8 else: weight_granularity = PerGroup(group_size=group_size) + assert weight_dtype == torch.int4 - config = Int8DynamicActivationIntxWeightConfig( - weight_dtype=weight_dtype, - weight_granularity=weight_granularity, + config = AOQuantizationConfig( + Int8DynamicActivationIntxWeightConfig( + weight_dtype=weight_dtype, + weight_granularity=weight_granularity, + ) ) quant_recipe = QuantizationRecipe( quantizers=None, - ao_base_config=[config], + ao_quantization_configs=[config], ) return ExportRecipe( @@ -162,7 +173,10 @@ def _build_int8da_intx_weight_recipe( ) def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: - if recipe_type == XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR: + if ( + recipe_type + == XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR + ): expected_keys = {"group_size"} unexpected = set(kwargs.keys()) - expected_keys if unexpected: diff --git a/backends/xnnpack/recipes/xnnpack_recipe_types.py b/backends/xnnpack/recipes/xnnpack_recipe_types.py index 5675c3a5ffa..61117b94502 100644 --- a/backends/xnnpack/recipes/xnnpack_recipe_types.py +++ b/backends/xnnpack/recipes/xnnpack_recipe_types.py @@ -13,19 +13,22 @@ class XNNPackRecipeType(RecipeType): """XNNPACK-specific recipe types""" FP32 = "fp32" + + ## PT2E-based quantization recipes # INT8 Dynamic Quantization - INT8_DYNAMIC_PER_CHANNEL = "int8_dynamic_per_channel" + PT2E_INT8_DYNAMIC_PER_CHANNEL = "pt2e_int8_dynamic_per_channel" + # INT8 Static Quantization, needs calibration dataset + PT2E_INT8_STATIC_PER_CHANNEL = "pt2e_int8_static_per_channel" + PT2E_INT8_STATIC_PER_TENSOR = "pt2e_int8_static_per_tensor" + + ## TorchAO-based quantization recipes # INT8 Dynamic Activations INT4 Weight Quantization, Axis = 0 - INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8da_int4w_per_channel" + TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL = ( + "torchao_int8da_int4w_per_channel" + ) # INT8 Dynamic Activations INT4 Weight Quantization, default group_size = 32 # can be overriden by group_size kwarg - INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8da_int4w_per_tensor" - # INT8 Static Activations INT4 Weight Quantization - INT8_STATIC_ACT_INT4_WEIGHT_PER_CHANNEL = "int8a_int4w_per_channel" - INT8_STATIC_ACT_INT4_WEIGHT_PER_TENSOR = "int8a_int44w_per_tensor" - # INT8 Static Quantization, needs calibration dataset - INT8_STATIC_PER_CHANNEL = "int8_static_per_channel" - INT8_STATIC_PER_TENSOR = "int8_static_per_tensor" + TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR = "torchao_int8da_int4w_per_tensor" @classmethod def get_backend_name(cls) -> str: diff --git a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py index 679743e42d3..565b71eab71 100644 --- a/backends/xnnpack/test/recipes/test_xnnpack_recipes.py +++ b/backends/xnnpack/test/recipes/test_xnnpack_recipes.py @@ -18,9 +18,10 @@ from executorch.examples.models.model_factory import EagerModelFactory from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS, QuantType from executorch.exir.schema import DelegateCall, Program -from executorch.export import export, ExportRecipe, recipe_registry +from executorch.export import export, ExportRecipe, recipe_registry, StageType from torch import nn from torch.testing._internal.common_quantization 
import TestHelperModules +from torchao.quantization.utils import compute_error class TestXnnpackRecipes(unittest.TestCase): @@ -38,6 +39,29 @@ def check_fully_delegated(self, program: Program) -> None: self.assertEqual(len(instructions), 1) self.assertIsInstance(instructions[0].instr_args, DelegateCall) + # pyre-ignore + def _compare_eager_quantized_model_outputs( + self, session, example_inputs, atol: float + ) -> None: + """Utility to compare eager quantized model output with session output after xnnpack lowering""" + torch_export_stage_output = session.get_stage_artifacts()[ + StageType.TORCH_EXPORT + ] + eager_quantized_model = torch_export_stage_output.data["forward"].module() + output = session.run_method("forward", example_inputs[0])[0] + expected = eager_quantized_model(*example_inputs[0]) + Tester._assert_outputs_equal(output, expected, atol=atol) + + def _compare_eager_unquantized_model_outputs( + self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20 + ): + """Utility to compare eager unquantized model output with session output using SQNR""" + quantized_output = session.run_method("forward", example_inputs[0])[0] + original_output = eager_unquantized_model(*example_inputs[0]) + error = compute_error(original_output, quantized_output) + print(f"{self._testMethodName} - SQNR: {error} dB") + self.assertTrue(error > sqnr_threshold) + def test_basic_recipe(self) -> None: m_eager = TestHelperModules.TwoLinearModule().eval() example_inputs = [(torch.randn(9, 8),)] @@ -46,18 +70,13 @@ def test_basic_recipe(self) -> None: example_inputs=example_inputs, export_recipe=ExportRecipe.get_recipe(XNNPackRecipeType.FP32), ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-3, - ) - ) + self._compare_eager_quantized_model_outputs(session, example_inputs, 1e-3) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs(session, m_eager, example_inputs) def test_int8_dynamic_quant_recipe(self) -> None: test_cases = [ - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL), ] for export_recipe in test_cases: @@ -70,19 +89,18 @@ def test_int8_dynamic_quant_recipe(self) -> None: example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-1, - ) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-1 ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs( + session, m_eager, example_inputs + ) def test_int8_static_quant_recipe(self) -> None: test_cases = [ - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_CHANNEL), - ExportRecipe.get_recipe(XNNPackRecipeType.INT8_STATIC_PER_TENSOR), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL), + ExportRecipe.get_recipe(XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR), ] for export_recipe in test_cases: @@ -95,14 +113,13 @@ def test_int8_static_quant_recipe(self) -> None: example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-1, - ) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-2 ) 
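# --- Illustrative sketch (annotation, not part of the diff) ---
# What the new AOQuantizationConfig wrapper enables: SourceTransformStage now forwards an
# optional filter_fn to torchao's quantize_, so a recipe can restrict quantization to
# selected submodules. The config mirrors the XNNPACK TorchAO 8da4w recipe above; the
# helper name and the Linear-only filter are placeholders chosen for illustration.
import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig

def apply_8da4w_to_linears_only(model: torch.nn.Module) -> torch.nn.Module:
    config = Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size=32),
    )
    # filter_fn has the same shape as AOQuantizationConfig.filter_fn: (module, fqn) -> bool
    quantize_(model, config, lambda m, fqn: isinstance(m, torch.nn.Linear))
    return model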
self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_unquantized_model_outputs( + session, m_eager, example_inputs + ) def test_8a4w_recipe(self) -> None: class SimpleLinearModel(nn.Module): @@ -116,40 +133,36 @@ def forward(self, x) -> torch.Tensor: test_cases = [ ExportRecipe.get_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL, + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_CHANNEL, ), ExportRecipe.get_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, - group_size=32, + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, + group_size=8, ), ] for export_recipe in test_cases: with self.subTest(export_recipe=export_recipe): - model = SimpleLinearModel() + model = SimpleLinearModel().eval() example_inputs = [(torch.randn(1, 32),)] session = export( model=model, example_inputs=example_inputs, export_recipe=export_recipe, ) - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - model(*example_inputs[0]), - atol=1e-2, - ) - ) self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_quantized_model_outputs( + session, example_inputs, 1e-3 + ) def _get_recipe_for_quant_type(self, quant_type: QuantType) -> XNNPackRecipeType: # Map QuantType to corresponding recipe name. if quant_type == QuantType.STATIC_PER_CHANNEL: - return XNNPackRecipeType.INT8_STATIC_PER_CHANNEL + return XNNPackRecipeType.PT2E_INT8_STATIC_PER_CHANNEL elif quant_type == QuantType.DYNAMIC_PER_CHANNEL: - return XNNPackRecipeType.INT8_DYNAMIC_PER_CHANNEL + return XNNPackRecipeType.PT2E_INT8_DYNAMIC_PER_CHANNEL elif quant_type == QuantType.STATIC_PER_TENSOR: - return XNNPackRecipeType.INT8_STATIC_PER_TENSOR + return XNNPackRecipeType.PT2E_INT8_STATIC_PER_TENSOR elif quant_type == QuantType.NONE: return XNNPackRecipeType.FP32 else: @@ -224,12 +237,13 @@ def test_validate_recipe_kwargs_int4_tensor_with_valid_group_size( # Should not raise any exception recipe_w_default_group = provider.create_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR ) self.assertIsNotNone(recipe_w_default_group) recipe = provider.create_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size=64 + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, + group_size=64, ) self.assertIsNotNone(recipe) @@ -240,7 +254,7 @@ def test_validate_recipe_kwargs_int4_tensor_with_invalid_group_size( with self.assertRaises(ValueError) as cm: provider.create_recipe( - XNNPackRecipeType.INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, + XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR, group_size="32", # String instead of int ) diff --git a/export/__init__.py b/export/__init__.py index d5f3826ab90..a7b165185de 100644 --- a/export/__init__.py +++ b/export/__init__.py @@ -15,12 +15,19 @@ """ from .export import export, ExportSession -from .recipe import ExportRecipe, LoweringRecipe, QuantizationRecipe, RecipeType +from .recipe import ( + AOQuantizationConfig, + ExportRecipe, + LoweringRecipe, + QuantizationRecipe, + RecipeType, +) from .recipe_provider import BackendRecipeProvider from .recipe_registry import recipe_registry from .types import StageType __all__ = [ + "AOQuantizationConfig", "StageType", "ExportRecipe", "LoweringRecipe", diff --git a/export/recipe.py b/export/recipe.py index 8f7251cd419..086d57f3e38 100644 --- a/export/recipe.py +++ b/export/recipe.py @@ -6,7 +6,9 @@ from 
abc import ABCMeta, abstractmethod from dataclasses import dataclass from enum import Enum, EnumMeta -from typing import List, Optional, Sequence +from typing import Callable, List, Optional, Sequence + +import torch from executorch.exir._warnings import experimental @@ -64,6 +66,20 @@ class Mode(str, Enum): RELEASE = "release" +@dataclass +class AOQuantizationConfig: + """ + Configuration for torchao quantization with optional filter function. + + Attributes: + ao_base_config: The AOBaseConfig for quantization + filter_fn: Optional filter function to selectively apply quantization + """ + + ao_base_config: AOBaseConfig + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None + + @dataclass class QuantizationRecipe: """ @@ -73,11 +89,12 @@ class QuantizationRecipe: Attributes: quantizers: Optional list of quantizers for model quantization - ao_base_config: Optional list of AO base configurations + ao_quantization_configs: Optional list of AOQuantizationConfig objects that pair + AOBaseConfig with optional filter functions """ quantizers: Optional[List[Quantizer]] = None - ao_base_config: Optional[List[AOBaseConfig]] = None + ao_quantization_configs: Optional[List[AOQuantizationConfig]] = None def get_quantizers(self) -> Optional[List[Quantizer]]: """ diff --git a/export/stages.py b/export/stages.py index f4de59a9b7a..2b3f8a42440 100644 --- a/export/stages.py +++ b/export/stages.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import copy import logging from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Optional, Sequence @@ -20,7 +21,10 @@ from torch._export.pass_base import PassType from torchao.quantization import quantize_ from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e -from torchao.quantization.pt2e.quantizer import ComposableQuantizer +from torchao.quantization.pt2e.quantizer import ( + ComposableQuantizer, + Quantizer as TorchAOPT2EQuantizer, +) from torchao.utils import unwrap_tensor_subclass @@ -289,7 +293,7 @@ def run(self, artifact: PipelineArtifact) -> None: """ if ( not self._quantization_recipe - or not self._quantization_recipe.ao_base_config + or not self._quantization_recipe.ao_quantization_configs ): logging.info( "Quantization recipe is invalid to run SourceTransform, returning original artifact" @@ -300,15 +304,14 @@ def run(self, artifact: PipelineArtifact) -> None: assert isinstance(artifact.data, dict) # Store the original models - self._transformed_models = artifact.data + self._transformed_models = copy.deepcopy(artifact.data) # Apply torchao quantize_ to each model - for method_name, model in artifact.data.items(): + for _, model in artifact.data.items(): # pyre-ignore - for config in self._quantization_recipe.ao_base_config: - quantize_(model, config) + for ao_config in self._quantization_recipe.ao_quantization_configs: + quantize_(model, ao_config.ao_base_config, ao_config.filter_fn) unwrap_tensor_subclass(model) - self._transformed_models[method_name] = model self._artifact = artifact.copy_with_new_data(self._transformed_models) @@ -333,6 +336,36 @@ def valid_predecessor_stages(self) -> List["StageType"]: def can_start_pipeline(self) -> bool: return True + def _get_quantizer_for_prepare_pt2e(self, quantizers: List[Any]): + torch_ao_quantizers = [] + torchao_pt2e_quantizers = [] + + for quantizer in quantizers: + if isinstance(quantizer, TorchAOPT2EQuantizer): + 
torchao_pt2e_quantizers.append(quantizer) + else: + # torch.ao quantizer support will soon be deprecated, remove this once CoreML moves to torchao quantizer + logging.warning( + f"torch.ao quantizer {quantizer} is deprecated, consider moving to torchao quantizer" + ) + torch_ao_quantizers.append(quantizer) + + if torch_ao_quantizers and torchao_pt2e_quantizers: + raise ValueError("Mixed quantizer types are not supported") + if len(torch_ao_quantizers) > 1: + raise ValueError( + "Multiple quantizers of torch.ao.quantization.quantizer not supported" + ) + + if torch_ao_quantizers: + # prepare_pt2e has backward compat with torch.ao quantizer + return torch_ao_quantizers[0] + elif torchao_pt2e_quantizers: + # Multiple torchao quantizers - use ComposableQuantizer + return ComposableQuantizer(torchao_pt2e_quantizers) + else: + raise ValueError("No quantizers detected") + def run(self, artifact: PipelineArtifact) -> None: if not self._quantization_recipe or not self._quantization_recipe.quantizers: logging.info( @@ -357,11 +390,10 @@ def run(self, artifact: PipelineArtifact) -> None: inputs = example_inputs[method_name][0] captured_graph = torch.export.export(model, inputs, strict=True).module() - composed_quantizer = ComposableQuantizer( - # pyre-ignore + quantizer = self._get_quantizer_for_prepare_pt2e( self._quantization_recipe.quantizers ) - prepared_model = prepare_pt2e(captured_graph, composed_quantizer) + prepared_model = prepare_pt2e(captured_graph, quantizer) for calibration_input in example_inputs[method_name]: prepared_model(*calibration_input) diff --git a/export/tests/TARGETS b/export/tests/TARGETS index 068c3436b6a..56534140976 100644 --- a/export/tests/TARGETS +++ b/export/tests/TARGETS @@ -14,7 +14,7 @@ runtime.python_test( "//executorch/runtime:runtime", ] ) - +z runtime.python_test( name = "test_executorch_export", srcs = [ diff --git a/export/tests/test_export_session.py b/export/tests/test_export_session.py index 30288941d22..fcec1b7a59a 100644 --- a/export/tests/test_export_session.py +++ b/export/tests/test_export_session.py @@ -12,7 +12,11 @@ import torch from executorch.export import ExportRecipe, ExportSession -from executorch.export.recipe import LoweringRecipe, QuantizationRecipe +from executorch.export.recipe import ( + AOQuantizationConfig, + LoweringRecipe, + QuantizationRecipe, +) from executorch.export.stages import PipelineArtifact from executorch.export.types import StageType @@ -20,7 +24,7 @@ class SimpleTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.linear = torch.nn.Linear(10, 5) + self.linear: torch.nn.Module = torch.nn.Linear(10, 5) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -449,7 +453,7 @@ def test_pipeline_building_with_all_recipes(self) -> None: """Test pipeline building with quantization and lowering recipes.""" # Create comprehensive recipes quant_recipe = QuantizationRecipe( - ao_base_config=[Mock()], + ao_quantization_configs=[AOQuantizationConfig(Mock())], quantizers=[Mock()], ) lowering_recipe = LoweringRecipe( diff --git a/export/tests/test_export_stages.py b/export/tests/test_export_stages.py index 4820e508e18..d4629a1aea7 100644 --- a/export/tests/test_export_stages.py +++ b/export/tests/test_export_stages.py @@ -11,25 +11,25 @@ import torch from executorch.exir.program import EdgeProgramManager, ExecutorchProgramManager -from executorch.export import QuantizationRecipe +from executorch.export import AOQuantizationConfig, QuantizationRecipe, StageType from 
executorch.export.stages import ( EdgeTransformAndLowerStage, ExecutorchStage, PipelineArtifact, QuantizeStage, SourceTransformStage, - StageType, ToBackendStage, ToEdgeStage, TorchExportStage, ) from torch.export import ExportedProgram +from torchao.quantization.pt2e.quantizer import Quantizer as TorchAOPT2EQuantizer class SimpleTestModel(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.linear = torch.nn.Linear(10, 5) + self.linear: torch.nn.Module = torch.nn.Linear(10, 5) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) @@ -163,7 +163,7 @@ def setUp(self) -> None: def test_source_transform_stage_no_quantization(self) -> None: mock_recipe = Mock(spec=QuantizationRecipe) - mock_recipe.ao_base_config = None + mock_recipe.ao_quantization_configs = None stage = SourceTransformStage(mock_recipe) artifact = PipelineArtifact(data=self.models_dict, context={}) @@ -174,12 +174,18 @@ def test_source_transform_stage_no_quantization(self) -> None: @patch("executorch.export.stages.quantize_") @patch("executorch.export.stages.unwrap_tensor_subclass") - def test_run_with_ao_base_config( + def test_run_with_ao_quantization_configs( self, mock_unwrap: Mock, mock_quantize: Mock ) -> None: - mock_config = Mock() + from torchao.core.config import AOBaseConfig + + mock_config = Mock(spec=AOBaseConfig) + mock_filter_fn = Mock() + mock_ao_config: AOQuantizationConfig = AOQuantizationConfig( + ao_base_config=mock_config, filter_fn=mock_filter_fn + ) mock_recipe = Mock(spec=QuantizationRecipe) - mock_recipe.ao_base_config = [mock_config] + mock_recipe.ao_quantization_configs = [mock_ao_config] stage = SourceTransformStage(mock_recipe) @@ -188,7 +194,7 @@ def test_run_with_ao_base_config( stage.run(artifact) # Verify quantize_ was called with the model and config - mock_quantize.assert_called_once_with(self.model, mock_config) + mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn) # Verify unwrap_tensor_subclass was called with the model mock_unwrap.assert_called_once_with(self.model) @@ -201,6 +207,21 @@ def setUp(self) -> None: self.example_inputs = [(torch.randn(2, 10),)] self.context = {"example_inputs": {"forward": self.example_inputs}} + @staticmethod + def create_dummy_quantizer() -> TorchAOPT2EQuantizer: + + class DummyQuantizer(TorchAOPT2EQuantizer): + def __init__(self): + pass + + def annotate(self, model): + return model + + def validate(self, model): + pass + + return DummyQuantizer() + def test_run_no_quantizers(self) -> None: """Test execution with no quantizers.""" mock_recipe = Mock(spec=QuantizationRecipe) @@ -224,7 +245,7 @@ def test_run_with_quantizers( mock_convert_pt2e: Mock, ) -> None: """Test execution with quantizers""" - mock_quantizer = Mock() + mock_quantizer = self.create_dummy_quantizer() mock_recipe = Mock(spec=QuantizationRecipe) mock_recipe.quantizers = [mock_quantizer] stage = QuantizeStage(mock_recipe) @@ -285,6 +306,35 @@ def test_run_empty_example_inputs(self) -> None: "Example inputs for method forward not found or empty", str(cm.exception) ) + @patch("executorch.export.stages.ComposableQuantizer") + def test_get_quantizer_for_prepare_pt2e( + self, mock_composable_quantizer: Mock + ) -> None: + """Test _get_quantizer_for_prepare_pt2e method with different quantizer scenarios.""" + mock_recipe = Mock(spec=QuantizationRecipe) + stage = QuantizeStage(mock_recipe) + + # Test empty quantizers list - should raise ValueError + with self.assertRaises(ValueError) as cm: + stage._get_quantizer_for_prepare_pt2e([]) 
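# --- Illustrative sketch (annotation, not part of the diff) ---
# Condensed view of the PT2E path QuantizeStage takes when the recipe carries torchao
# PT2E quantizers: they are composed with ComposableQuantizer before prepare_pt2e (a
# single legacy torch.ao quantizer would instead be handed to prepare_pt2e directly).
# The names model, quantizers and calibration_inputs stand in for the stage's real
# per-method inputs.
import torch
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
from torchao.quantization.pt2e.quantizer import ComposableQuantizer

def pt2e_quantize_sketch(model, quantizers, calibration_inputs):
    captured = torch.export.export(model, calibration_inputs[0], strict=True).module()
    prepared = prepare_pt2e(captured, ComposableQuantizer(quantizers))
    for inputs in calibration_inputs:  # calibration pass over the example inputs
        prepared(*inputs)
    return convert_pt2e(prepared)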
+ self.assertIn("No quantizers detected", str(cm.exception)) + + # Test ComposableQuantizer path with multiple torchao quantizers + # Create instances of dummy quantizers using the reusable method + quantizer1 = self.create_dummy_quantizer() + quantizer2 = self.create_dummy_quantizer() + + # Set up ComposableQuantizer mock + mock_composed_quantizer = Mock() + mock_composable_quantizer.return_value = mock_composed_quantizer + + # Call the method with multiple torchao quantizers + result = stage._get_quantizer_for_prepare_pt2e([quantizer1, quantizer2]) + + # Verify ComposableQuantizer was called with the quantizers + mock_composable_quantizer.assert_called_once_with([quantizer1, quantizer2]) + self.assertEqual(result, mock_composed_quantizer) + class TestToEdgeStage(unittest.TestCase): def setUp(self) -> None: From ed02934242520f7c89a8ebc1e815edd8c8baeb06 Mon Sep 17 00:00:00 2001 From: Abhinay Kukkadapu Date: Thu, 14 Aug 2025 14:28:11 -0700 Subject: [PATCH 2/2] [executorch] Add coreml quant recipes Pull Request resolved: https://github.com/pytorch/executorch/pull/13387 Fixing tests for stack that got reverted: https://github.com/pytorch/executorch/pull/13265 Adds coreml quant recipes after FP32/16 recipes added in #13121 Recipes added: PT2E_INT8_STATIC PT2E_INT8_WEIGHT_ONLY TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP CODEBOOK_WEIGHT_ONLY ghstack-source-id: 303126085 @exported-using-ghexport Differential Revision: [D80206542](https://our.internmc.facebook.com/intern/diff/D80206542/) --- backends/apple/coreml/TARGETS | 2 + .../coreml/recipes/coreml_recipe_provider.py | 294 +++++++- .../coreml/recipes/coreml_recipe_types.py | 36 +- .../apple/coreml/test/test_coreml_recipes.py | 644 +++++++++++++----- 4 files changed, 801 insertions(+), 175 deletions(-) diff --git a/backends/apple/coreml/TARGETS b/backends/apple/coreml/TARGETS index 6993b699427..22cb20d9065 100644 --- a/backends/apple/coreml/TARGETS +++ b/backends/apple/coreml/TARGETS @@ -120,11 +120,13 @@ runtime.python_test( "test/*.py", ]), deps = [ + "fbsource//third-party/pypi/coremltools:coremltools", "fbsource//third-party/pypi/pytest:pytest", ":partitioner", ":quantizer", ":recipes", "//caffe2:torch", "//pytorch/vision:torchvision", + "fbsource//third-party/pypi/scikit-learn:scikit-learn", ], ) diff --git a/backends/apple/coreml/recipes/coreml_recipe_provider.py b/backends/apple/coreml/recipes/coreml_recipe_provider.py index 75c937027bb..90b798f9e0c 100644 --- a/backends/apple/coreml/recipes/coreml_recipe_provider.py +++ b/backends/apple/coreml/recipes/coreml_recipe_provider.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Sequence import coremltools as ct +import torch from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition.coreml_partitioner import ( @@ -18,11 +19,15 @@ from executorch.exir import EdgeCompileConfig from executorch.export import ( + AOQuantizationConfig, BackendRecipeProvider, ExportRecipe, LoweringRecipe, + QuantizationRecipe, RecipeType, ) +from torchao.quantization.granularity import PerAxis, PerGroup +from torchao.quantization.quant_api import IntxWeightOnlyConfig class CoreMLRecipeProvider(BackendRecipeProvider): @@ -50,34 +55,98 @@ def create_recipe( # Validate kwargs self._validate_recipe_kwargs(recipe_type, **kwargs) - # Parse recipe type to get precision and compute unit - precision = None if recipe_type == CoreMLRecipeType.FP32: - precision = 
ct.precision.FLOAT32 + return self._build_fp_recipe(recipe_type, ct.precision.FLOAT32, **kwargs) elif recipe_type == CoreMLRecipeType.FP16: - precision = ct.precision.FLOAT16 - - if precision is None: - raise ValueError(f"Unknown precision for recipe: {recipe_type.value}") + return self._build_fp_recipe(recipe_type, ct.precision.FLOAT16, **kwargs) + elif recipe_type == CoreMLRecipeType.PT2E_INT8_STATIC: + return self._build_pt2e_quantized_recipe( + recipe_type, activation_dtype=torch.quint8, **kwargs + ) + elif recipe_type == CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY: + return self._build_pt2e_quantized_recipe( + recipe_type, activation_dtype=torch.float32, **kwargs + ) + elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL: + return self._build_torchao_quantized_recipe( + recipe_type, + weight_dtype=torch.int4, + is_per_channel=True, + **kwargs, + ) + elif recipe_type == CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP: + group_size = kwargs.pop("group_size", 32) + return self._build_torchao_quantized_recipe( + recipe_type, + weight_dtype=torch.int4, + is_per_channel=False, + group_size=group_size, + **kwargs, + ) + elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL: + return self._build_torchao_quantized_recipe( + recipe_type, weight_dtype=torch.int8, is_per_channel=True, **kwargs + ) + elif recipe_type == CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP: + group_size = kwargs.pop("group_size", 32) + return self._build_torchao_quantized_recipe( + recipe_type, + weight_dtype=torch.int8, + is_per_channel=False, + group_size=group_size, + **kwargs, + ) + elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: + bits = kwargs.pop("bits") + block_size = kwargs.pop("block_size") + return self._build_codebook_quantized_recipe( + recipe_type, bits=bits, block_size=block_size, **kwargs + ) - return self._build_recipe(recipe_type, precision, **kwargs) + return None def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> None: - if not kwargs: - return - expected_keys = {"minimum_deployment_target", "compute_unit"} + """Validate kwargs for each recipe type""" + expected_keys = self._get_expected_keys(recipe_type) + unexpected = set(kwargs.keys()) - expected_keys if unexpected: raise ValueError( - f"CoreML Recipes only accept 'minimum_deployment_target' or 'compute_unit' as parameter. 
" - f"Unexpected parameters: {list(unexpected)}" + f"Recipe '{recipe_type.value}' received unexpected parameters: {list(unexpected)}" ) + + self._validate_base_parameters(kwargs) + self._validate_group_size_parameter(recipe_type, kwargs) + self._validate_codebook_parameters(recipe_type, kwargs) + + def _get_expected_keys(self, recipe_type: RecipeType) -> set: + """Get expected parameter keys for a recipe type""" + common_keys = {"minimum_deployment_target", "compute_unit"} + + if recipe_type in [ + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, + CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, + ]: + return common_keys | {"group_size", "filter_fn"} + elif recipe_type in [ + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL, + CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL, + ]: + return common_keys | {"filter_fn"} + elif recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: + return common_keys | {"bits", "block_size", "filter_fn"} + else: + return common_keys + + def _validate_base_parameters(self, kwargs: Any) -> None: + """Validate minimum_deployment_target and compute_unit parameters""" if "minimum_deployment_target" in kwargs: minimum_deployment_target = kwargs["minimum_deployment_target"] if not isinstance(minimum_deployment_target, ct.target): raise ValueError( f"Parameter 'minimum_deployment_target' must be an enum of type ct.target, got {type(minimum_deployment_target)}" ) + if "compute_unit" in kwargs: compute_unit = kwargs["compute_unit"] if not isinstance(compute_unit, ct.ComputeUnit): @@ -85,12 +154,79 @@ def _validate_recipe_kwargs(self, recipe_type: RecipeType, **kwargs: Any) -> Non f"Parameter 'compute_unit' must be an enum of type ct.ComputeUnit, got {type(compute_unit)}" ) - def _build_recipe( + def _validate_group_size_parameter( + self, recipe_type: RecipeType, kwargs: Any + ) -> None: + """Validate group_size parameter for applicable recipe types""" + if ( + recipe_type + in [ + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, + CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, + ] + and "group_size" in kwargs + ): + group_size = kwargs["group_size"] + if not isinstance(group_size, int): + raise ValueError( + f"Parameter 'group_size' must be an integer, got {type(group_size).__name__}: {group_size}" + ) + if group_size <= 0: + raise ValueError( + f"Parameter 'group_size' must be positive, got: {group_size}" + ) + + def _validate_codebook_parameters( + self, recipe_type: RecipeType, kwargs: Any + ) -> None: + """Validate bits and block_size parameters for codebook recipe type""" + if recipe_type != CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY: + return + + # Both bits and block_size must be present + if not ("bits" in kwargs and "block_size" in kwargs): + raise ValueError( + "Parameters 'bits' and 'block_size' must be present for codebook recipes" + ) + + if "bits" in kwargs: + bits = kwargs["bits"] + if not isinstance(bits, int): + raise ValueError( + f"Parameter 'bits' must be an integer, got {type(bits).__name__}: {bits}" + ) + if not (1 <= bits <= 8): + raise ValueError( + f"Parameter 'bits' must be between 1 and 8, got: {bits}" + ) + + if "block_size" in kwargs: + block_size = kwargs["block_size"] + if not isinstance(block_size, list): + raise ValueError( + f"Parameter 'block_size' must be a list, got {type(block_size).__name__}: {block_size}" + ) + + def _validate_and_set_deployment_target( + self, kwargs: Any, min_target: ct.target, quantization_type: str + ) -> None: + """Validate or set minimum deployment target for quantization recipes""" 
+ minimum_deployment_target = kwargs.get("minimum_deployment_target", None) + if minimum_deployment_target and minimum_deployment_target < min_target: + raise ValueError( + f"minimum_deployment_target must be {str(min_target)} or higher for {quantization_type} quantization" + ) + else: + # Default to the minimum target for this quantization type + kwargs["minimum_deployment_target"] = min_target + + def _build_fp_recipe( self, recipe_type: RecipeType, precision: ct.precision, **kwargs: Any, ) -> ExportRecipe: + """Build FP32/FP16 recipe""" lowering_recipe = self._get_coreml_lowering_recipe( compute_precision=precision, **kwargs, @@ -98,18 +234,142 @@ def _build_recipe( return ExportRecipe( name=recipe_type.value, - quantization_recipe=None, # TODO - add quantization recipe + lowering_recipe=lowering_recipe, + ) + + def _build_pt2e_quantized_recipe( + self, + recipe_type: RecipeType, + activation_dtype: torch.dtype, + **kwargs: Any, + ) -> ExportRecipe: + """Build PT2E-based quantization recipe""" + from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer + + self._validate_and_set_deployment_target(kwargs, ct.target.iOS17, "pt2e") + + # Validate activation_dtype + assert activation_dtype in [ + torch.quint8, + torch.float32, + ], f"activation_dtype must be torch.quint8 or torch.float32, got {activation_dtype}" + + # Create quantization config + config = ct.optimize.torch.quantization.LinearQuantizerConfig( + global_config=ct.optimize.torch.quantization.ModuleLinearQuantizerConfig( + quantization_scheme="symmetric", + activation_dtype=activation_dtype, + weight_dtype=torch.qint8, + weight_per_channel=True, + ) + ) + + quantizer = CoreMLQuantizer(config) + quantization_recipe = QuantizationRecipe(quantizers=[quantizer]) + + lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) + + return ExportRecipe( + name=recipe_type.value, + quantization_recipe=quantization_recipe, + lowering_recipe=lowering_recipe, + ) + + def _build_torchao_quantized_recipe( + self, + recipe_type: RecipeType, + weight_dtype: torch.dtype, + is_per_channel: bool, + group_size: int = 32, + **kwargs: Any, + ) -> ExportRecipe: + """Build TorchAO-based quantization recipe""" + if is_per_channel: + weight_granularity = PerAxis(axis=0) + else: + weight_granularity = PerGroup(group_size=group_size) + + # Use user-provided filter_fn if provided + filter_fn = kwargs.get("filter_fn", None) + config = AOQuantizationConfig( + ao_base_config=IntxWeightOnlyConfig( + weight_dtype=weight_dtype, + granularity=weight_granularity, + ), + filter_fn=filter_fn, + ) + + quantization_recipe = QuantizationRecipe( + quantizers=None, + ao_quantization_configs=[config], + ) + + # override minimum_deployment_target to ios18 for torchao (GH issue #13122) + self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "torchao") + lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) + + return ExportRecipe( + name=recipe_type.value, + quantization_recipe=quantization_recipe, + lowering_recipe=lowering_recipe, + ) + + def _build_codebook_quantized_recipe( + self, + recipe_type: RecipeType, + bits: int, + block_size: list, + **kwargs: Any, + ) -> ExportRecipe: + """Build codebook/palettization quantization recipe""" + from torchao.prototype.quantization.codebook_coreml import ( + CodebookWeightOnlyConfig, + ) + + self._validate_and_set_deployment_target(kwargs, ct.target.iOS18, "codebook") + + # Get the appropriate dtype (torch.uint1 through torch.uint8) + dtype = getattr(torch, f"uint{bits}") + + # Use user-provided 
filter_fn or default to Linear/Embedding layers + filter_fn = kwargs.get( + "filter_fn", + lambda m, fqn: ( + isinstance(m, torch.nn.Embedding) or isinstance(m, torch.nn.Linear) + ), + ) + + config = AOQuantizationConfig( + ao_base_config=CodebookWeightOnlyConfig( + dtype=dtype, + block_size=block_size, + ), + filter_fn=filter_fn, + ) + + quantization_recipe = QuantizationRecipe( + quantizers=None, + ao_quantization_configs=[config], + ) + + lowering_recipe = self._get_coreml_lowering_recipe(**kwargs) + + return ExportRecipe( + name=recipe_type.value, + quantization_recipe=quantization_recipe, lowering_recipe=lowering_recipe, ) def _get_coreml_lowering_recipe( self, - compute_precision: ct.precision, + compute_precision: ct.precision = ct.precision.FLOAT16, **kwargs: Any, ) -> LoweringRecipe: + """Get CoreML lowering recipe with optional precision""" compile_specs = CoreMLBackend.generate_compile_specs( compute_precision=compute_precision, - **kwargs, + compute_unit=kwargs.get("compute_unit", ct.ComputeUnit.ALL), + minimum_deployment_target=kwargs.get("minimum_deployment_target", None), ) minimum_deployment_target = kwargs.get("minimum_deployment_target", None) diff --git a/backends/apple/coreml/recipes/coreml_recipe_types.py b/backends/apple/coreml/recipes/coreml_recipe_types.py index 77f808bd982..fc7292c3c58 100644 --- a/backends/apple/coreml/recipes/coreml_recipe_types.py +++ b/backends/apple/coreml/recipes/coreml_recipe_types.py @@ -12,14 +12,42 @@ class CoreMLRecipeType(RecipeType): """CoreML-specific generic recipe types""" - # FP32 generic recipe, defaults to values published by the CoreML backend and partitioner - # Precision = FP32, Default compute_unit = All (can be overriden by kwargs) + ## All the recipes accept common kwargs + # 1. minimum_deployment_unit (default: None) + # 2. 
compute_unit (default: ct.ComputeUnit.ALL) + + # FP32 precision recipe, defaults to values published by the CoreML backend and partitioner FP32 = "coreml_fp32" - # FP16 generic recipe, defaults to values published by the CoreML backend and partitioner - # Precision = FP32, Default compute_unit = All (can be overriden by kwargs) + # FP16 precision recipe, defaults to values published by the CoreML backend and partitioner FP16 = "coreml_fp16" + ## PT2E-based quantization recipes + # INT8 Static Quantization (weights + activations), requires calibration dataset + PT2E_INT8_STATIC = "coreml_pt2e_int8_static" + # INT8 Weight-only Quantization (activations remain FP32) + PT2E_INT8_WEIGHT_ONLY = "coreml_pt2e_int8_weight_only" + + ## TorchAO-based quantization recipes + # All TorchAO recipes accept filter_fn kwarg to control which layers are quantized + # INT4 Weight-only Quantization, per-channel (axis=0) + # Additional kwargs: filter_fn (default: Embedding and linear layers) + TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL = "coreml_torchao_int4_weight_only_per_channel" + # INT4 Weight-only Quantization, per-group + # Additional kwargs: group_size (default: 32), filter_fn (default: Embedding and linear layers) + TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP = "coreml_torchao_int4_weight_only_per_group" + # INT8 Weight-only Quantization, per-channel (axis=0) + # Additional kwargs: filter_fn (default: Embedding and linear layers) + TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL = "coreml_torchao_int8_weight_only_per_channel" + # INT8 Weight-only Quantization, per-group + # Additional kwargs: group_size (default: 32), filter_fn (default: Embedding and linear layers) + TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP = "coreml_torchao_int8_weight_only_per_group" + + ## Codebook/Palettization Quantization + # Additional mandatory kwargs: bits (range: 1-8), block_size (list of ints), + # filter_fn (default: targets Linear and Embedding layers) + CODEBOOK_WEIGHT_ONLY = "coreml_codebook_weight_only" + @classmethod def get_backend_name(cls) -> str: return COREML_BACKEND diff --git a/backends/apple/coreml/test/test_coreml_recipes.py b/backends/apple/coreml/test/test_coreml_recipes.py index ca5c6c30c9c..7a78836b2bc 100644 --- a/backends/apple/coreml/test/test_coreml_recipes.py +++ b/backends/apple/coreml/test/test_coreml_recipes.py @@ -4,11 +4,10 @@ import unittest -from typing import List import coremltools as ct - import torch + from executorch.backends.apple.coreml.recipes import ( CoreMLRecipeProvider, CoreMLRecipeType, @@ -17,19 +16,16 @@ from executorch.backends.apple.coreml.test.test_coreml_utils import ( IS_VALID_TEST_RUNTIME, ) -from executorch.exir.schema import DelegateCall, Program -from executorch.export import export, ExportRecipe, recipe_registry +from executorch.exir.schema import DelegateCall +from executorch.export import export, ExportRecipe, recipe_registry, StageType + from torch import nn from torch.testing._internal.common_quantization import TestHelperModules +from torchao.quantization.utils import compute_error class TestCoreMLRecipes(unittest.TestCase): - fp32_recipes: List[CoreMLRecipeType] = [ - CoreMLRecipeType.FP32, - ] - fp16_recipes: List[CoreMLRecipeType] = [ - CoreMLRecipeType.FP16, - ] + """Test suite for CoreML recipes focusing on quantization functionality""" def setUp(self): torch._dynamo.reset() @@ -41,198 +37,538 @@ def setUp(self): def tearDown(self): super().tearDown() - def check_fully_delegated(self, program: Program) -> None: + def check_fully_delegated(self, session) -> None: + """Helper to verify a 
program is fully delegated to CoreML""" + session.print_delegation_info() + program = session.get_executorch_program() instructions = program.execution_plan[0].chains[0].instructions assert instructions is not None self.assertEqual(len(instructions), 1) self.assertIsInstance(instructions[0].instr_args, DelegateCall) - def test_all_fp32_recipes_with_simple_model(self): - """Test all FP32 recipes with a simple linear model""" - for recipe_type in self.fp32_recipes: - with self.subTest(recipe=recipe_type.value): - m_eager = TestHelperModules.TwoLinearModule().eval() - example_inputs = [(torch.randn(9, 8),)] + def _compare_eager_quantized_model_outputs(self, session, example_inputs, atol): + """Utility to compare eager quantized model output with session output after coreml lowering""" + if IS_VALID_TEST_RUNTIME: + source_transform_output = session.get_stage_artifacts()[ + StageType.SOURCE_TRANSFORM + ] + eager_quantized_model = source_transform_output.data["forward"] + output = session.run_method("forward", example_inputs[0])[0] + expected = eager_quantized_model(*example_inputs[0]) + self.assertTrue(torch.allclose(output, expected, atol=atol)) + + def _compare_eager_unquantized_model_outputs( + self, session, eager_unquantized_model, example_inputs, sqnr_threshold=20 + ): + """Utility to compare eager unquantized model output with session output using SQNR""" + if IS_VALID_TEST_RUNTIME: + quantized_output = session.run_method("forward", example_inputs[0])[0] + original_output = eager_unquantized_model(*example_inputs[0]) + error = compute_error(original_output, quantized_output) + print(f"SQNR: {error} dB") + self.assertTrue(error > sqnr_threshold) + + def test_fp32_recipe(self): + """Test FP32 recipe functionality""" + model = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + session = export( + model=model, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(CoreMLRecipeType.FP32), + ) + self.check_fully_delegated(session) + + self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-3) + self._compare_eager_unquantized_model_outputs(session, model, example_inputs) + + def test_fp16_recipe(self): + """Test FP16 recipe functionality""" + model = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + session = export( + model=model, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe(CoreMLRecipeType.FP16), + ) + self.check_fully_delegated(session) + + self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-3) + self._compare_eager_unquantized_model_outputs(session, model, example_inputs) + + def test_fp_recipes_with_custom_parameters(self): + """Test FP recipes with custom deployment target and compute unit""" + test_cases = [ + (CoreMLRecipeType.FP32, {"minimum_deployment_target": ct.target.iOS16}), + (CoreMLRecipeType.FP16, {"compute_unit": ct.ComputeUnit.CPU_ONLY}), + ] + + model = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + for recipe_type, kwargs in test_cases: + with self.subTest(recipe=recipe_type.value, kwargs=kwargs): session = export( - model=m_eager, + model=model, example_inputs=example_inputs, - export_recipe=ExportRecipe.get_recipe(recipe_type), - ) - self.check_fully_delegated(session.get_executorch_program()) - - # Verify outputs match - if IS_VALID_TEST_RUNTIME: - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - 
m_eager(*example_inputs[0]), - atol=1e-3, - ) - ) + export_recipe=ExportRecipe.get_recipe(recipe_type, **kwargs), + ) + self.check_fully_delegated(session) + + def test_int4_weight_only_per_channel(self): + """Test INT4 weight-only per-channel quantization""" + model = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + session = export( + model=model, + example_inputs=example_inputs, + export_recipe=ExportRecipe.get_recipe( + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL + ), + ) + self.check_fully_delegated(session) + self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-02) + self._compare_eager_unquantized_model_outputs(session, model, example_inputs) - def test_all_fp16_recipes_with_simple_model(self): - """Test all FP16 recipes with a simple linear model""" + def test_int4_weight_only_per_group(self): + """Test INT4 weight-only per-group quantization with different group sizes""" - for recipe_type in self.fp16_recipes: - with self.subTest(recipe=recipe_type.value): - m_eager = TestHelperModules.TwoLinearModule().eval() - example_inputs = [(torch.randn(9, 8),)] + class CustomTwoLinearModel(nn.Module): + def __init__(self): + super().__init__() + self.layer1 = nn.Linear(32, 32) + self.layer2 = nn.Linear(32, 8) + def forward(self, x): + x = torch.relu(self.layer1(x)) + x = self.layer2(x) + return x + + model = CustomTwoLinearModel().eval() + example_inputs = [(torch.randn(1, 32),)] + # Test with different group sizes + for group_size in [8, 16, 32]: + with self.subTest(group_size=group_size): session = export( - model=m_eager, + model=model, example_inputs=example_inputs, - export_recipe=ExportRecipe.get_recipe(recipe_type), + export_recipe=ExportRecipe.get_recipe( + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, + group_size=group_size, + ), ) + self.check_fully_delegated(session) - self.check_fully_delegated(session.get_executorch_program()) + self._compare_eager_quantized_model_outputs( + session, example_inputs, atol=1e-3 + ) + self._compare_eager_unquantized_model_outputs( + session, model, example_inputs + ) - # Verify outputs match (slightly higher tolerance for FP16) - if IS_VALID_TEST_RUNTIME: - self.assertTrue( - torch.allclose( - session.run_method("forward", example_inputs[0])[0], - m_eager(*example_inputs[0]), - atol=1e-3, - ) - ) + def test_int4_weight_only_per_group_validation(self): + """Test INT4 per-group parameter validation""" + # Test invalid group size type + with self.assertRaises(ValueError) as cm: + self.provider.create_recipe( + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, group_size="32" + ) + self.assertIn("must be an integer", str(cm.exception)) - def test_custom_simple_model(self): - """Test with a custom simple model""" + # Test negative group size + with self.assertRaises(ValueError) as cm: + self.provider.create_recipe( + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, group_size=-1 + ) + self.assertIn("must be positive", str(cm.exception)) - class CustomTestModel(nn.Module): + # Test unexpected parameter + with self.assertRaises(ValueError) as cm: + self.provider.create_recipe( + CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL, + group_size=32, # group_size not valid for per-channel + ) + self.assertIn("unexpected parameters", str(cm.exception)) + + def test_int8_weight_only_per_channel(self): + """Test INT8 weight-only per-channel quantization""" + model = TestHelperModules.TwoLinearModule().eval() + example_inputs = [(torch.randn(9, 8),)] + + session = export( 
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=ExportRecipe.get_recipe(
+                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL
+            ),
+        )
+        self.check_fully_delegated(session)
+
+        self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2)
+        self._compare_eager_unquantized_model_outputs(session, model, example_inputs)
+
+    def test_int8_weight_only_per_group(self):
+        """Test INT8 weight-only per-group quantization with different group sizes"""
+
+        class SimpleLinearModel(nn.Module):
             def __init__(self):
                 super().__init__()
-                self.linear1 = nn.Linear(10, 20)
-                self.relu = nn.ReLU()
-                self.linear2 = nn.Linear(20, 1)
+                self.layer = nn.Linear(64, 2)
 
             def forward(self, x):
-                x = self.linear1(x)
-                x = self.relu(x)
-                x = self.linear2(x)
-                return x
+                return self.layer(x)
 
-        model = CustomTestModel().eval()
-        example_inputs = [(torch.randn(1, 10),)]
-        for recipe_type in self.fp32_recipes + self.fp16_recipes:
-            with self.subTest(recipe=recipe_type.value):
+        model = SimpleLinearModel().eval()
+        example_inputs = [(torch.randn(1, 64),)]
+
+        # Test with different group sizes
+        for group_size in [16, 32, 64]:
+            with self.subTest(group_size=group_size):
                 session = export(
                     model=model,
                     example_inputs=example_inputs,
-                    export_recipe=ExportRecipe.get_recipe(recipe_type),
-                )
-                session.print_delegation_info()
-                self.check_fully_delegated(session.get_executorch_program())
-
-                if IS_VALID_TEST_RUNTIME:
-                    self.assertTrue(
-                        torch.allclose(
-                            session.run_method("forward", example_inputs[0])[0],
-                            model(*example_inputs[0]),
-                            atol=1e-3,
-                        )
-                    )
-
-    def test_unsupported_recipe_type(self):
-        """Test that unsupported recipe types return None"""
-        from executorch.export import RecipeType
+                    export_recipe=ExportRecipe.get_recipe(
+                        CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+                        group_size=group_size,
+                    ),
+                )
+                self.check_fully_delegated(session)
 
-        class UnsupportedRecipeType(RecipeType):
-            UNSUPPORTED = "unsupported"
+                self._compare_eager_quantized_model_outputs(
+                    session, example_inputs, atol=1e-2
+                )
+                self._compare_eager_unquantized_model_outputs(
+                    session, model, example_inputs
+                )
 
-            @classmethod
-            def get_backend_name(cls) -> str:
-                return "dummy"
+    def test_codebook_weight_only_recipe(self):
+        """Test codebook quantization recipe"""
 
-        recipe = self.provider.create_recipe(UnsupportedRecipeType.UNSUPPORTED)
-        self.assertIsNone(recipe)
+        class SimpleLinearModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer = nn.Linear(32, 2)
 
-    def test_recipe_registry_integration(self):
-        """Test that recipes work with the global recipe registry"""
-        for recipe_type in self.fp32_recipes + self.fp16_recipes:
-            with self.subTest(recipe=recipe_type.value):
-                recipe = ExportRecipe.get_recipe(recipe_type)
-                self.assertIsNotNone(recipe)
-                self.assertEqual(recipe.name, recipe_type.value)
+            def forward(self, x):
+                return self.layer(x)
 
-    def test_invalid_recipe_kwargs(self):
-        """Test detailed error messages for invalid kwargs"""
-        provider = CoreMLRecipeProvider()
+        model = SimpleLinearModel().eval()
+        example_inputs = [(torch.randn(1, 32),)]
 
-        # Test single invalid parameter
-        with self.assertRaises(ValueError) as cm:
-            provider.create_recipe(CoreMLRecipeType.FP16, invalid_param=123)
+        # Test different block sizes
+        test_cases = [
+            {"bits": 3, "block_size": [-1, 8]},
+        ]
 
-        error_msg = str(cm.exception)
-        self.assertIn("Unexpected parameters", error_msg)
+        for kwargs in test_cases:
+            with self.subTest(kwargs=kwargs):
+                session = export(
+                    model=model,
+                    example_inputs=example_inputs,
+                    export_recipe=ExportRecipe.get_recipe(
+                        CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, **kwargs
+                    ),
+                )
+                self.check_fully_delegated(session)
 
-        # Test multiple invalid parameters
+    def test_codebook_parameter_validation(self):
+        """Test codebook parameter validation"""
+        # Test invalid bits type
         with self.assertRaises(ValueError) as cm:
-            provider.create_recipe(
-                CoreMLRecipeType.FP32, param1="value1", param2="value2"
+            self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits="3", block_size=[-1, 8]
             )
+        self.assertIn("must be an integer", str(cm.exception))
 
-        error_msg = str(cm.exception)
-        self.assertIn("Unexpected parameters", error_msg)
+        # Test bits out of range
+        with self.assertRaises(ValueError) as cm:
+            self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits=0, block_size=[-1, 8]
+            )
+        self.assertIn("must be between 1 and 8", str(cm.exception))
 
-        # Test mix of valid and invalid parameters
         with self.assertRaises(ValueError) as cm:
-            provider.create_recipe(
-                CoreMLRecipeType.FP32,
-                minimum_deployment_target=ct.target.iOS16,  # valid
-                invalid_param="invalid",  # invalid
+            self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits=9, block_size=[-1, 8]
             )
+        self.assertIn("must be between 1 and 8", str(cm.exception))
 
-        error_msg = str(cm.exception)
-        self.assertIn("Unexpected parameters", error_msg)
+        # Test invalid block_size type
+        with self.assertRaises(ValueError) as cm:
+            self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits=3, block_size="[-1, 16]"
+            )
+        self.assertIn("must be a list", str(cm.exception))
 
-    def test_valid_kwargs(self):
-        """Test valid kwargs"""
-        recipe = self.provider.create_recipe(
-            CoreMLRecipeType.FP32,
-            minimum_deployment_target=ct.target.iOS16,
-            compute_unit=ct.ComputeUnit.CPU_AND_GPU,
-        )
-        self.assertIsNotNone(recipe)
-        self.assertEqual(recipe.name, "coreml_fp32")
+    def test_int8_static_quantization(self):
+        """Test INT8 static quantization (weights + activations)"""
 
-        # Verify partitioners are properly configured
-        partitioners = recipe.lowering_recipe.partitioners
-        self.assertEqual(len(partitioners), 1, "Expected exactly one partitioner")
+        class SimpleLinearModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.layer1 = nn.Linear(32, 16)
+                self.layer2 = nn.Linear(16, 2)
 
-        # Verify delegation spec and compile specs
-        delegation_spec = partitioners[0].delegation_spec
-        self.assertIsNotNone(delegation_spec, "Delegation spec should not be None")
+            def forward(self, x):
+                x = torch.relu(self.layer1(x))
+                x = self.layer2(x)
+                return x
 
-        compile_specs = delegation_spec.compile_specs
-        self.assertIsNotNone(compile_specs, "Compile specs should not be None")
+        model = SimpleLinearModel().eval()
+        example_inputs = [(torch.randn(1, 32),)]
 
-        spec_dict = {spec.key: spec.value for spec in compile_specs}
+        recipe = ExportRecipe.get_recipe(
+            CoreMLRecipeType.PT2E_INT8_STATIC, minimum_deployment_target=ct.target.iOS17
+        )
 
-        # Assert that all expected specs are present with correct values
-        self.assertIn(
-            "min_deployment_target",
-            spec_dict,
-            "minimum_deployment_target should be in compile specs",
+        session = export(
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=recipe,
         )
-        min_target_value = spec_dict["min_deployment_target"]
-        if isinstance(min_target_value, bytes):
-            min_target_value = min_target_value.decode("utf-8")
-        self.assertEqual(
-            str(min_target_value),
-            str(ct.target.iOS16.value),
-            "minimum_deployment_target should match the provided value",
+        self.check_fully_delegated(session)
+
+        self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-3)
+        self._compare_eager_unquantized_model_outputs(session, model, example_inputs)
+
+    def test_int8_weight_only_pt2e(self):
+        """Test PT2E-based INT8 weight-only quantization"""
+        model = TestHelperModules.TwoLinearModule().eval()
+        example_inputs = [(torch.randn(9, 8),)]
+
+        session = export(
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=ExportRecipe.get_recipe(
+                CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY
+            ),
         )
+        self.check_fully_delegated(session)
 
-        self.assertIn(
-            "compute_units", spec_dict, "compute_unit should be in compile specs"
-        )
-        compute_unit_value = spec_dict["compute_units"]
-        if isinstance(compute_unit_value, bytes):
-            compute_unit_value = compute_unit_value.decode("utf-8")
-        self.assertEqual(
-            str(compute_unit_value),
-            ct.ComputeUnit.CPU_AND_GPU.name.lower(),
-            "compute_unit should match the provided value",
+        self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2)
+        self._compare_eager_unquantized_model_outputs(session, model, example_inputs)
+
+    def test_int8_weight_only_pt2e_with_conv(self):
+        """Test PT2E-based INT8 weight-only quantization with convolution layers"""
+
+        class ConvModel(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
+                self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
+                self.pool = nn.AdaptiveAvgPool2d((1, 1))
+                self.fc = nn.Linear(32, 10)
+
+            def forward(self, x):
+                x = torch.relu(self.conv1(x))
+                x = torch.relu(self.conv2(x))
+                x = self.pool(x)
+                x = x.view(x.size(0), -1)
+                x = self.fc(x)
+                return x
+
+        model = ConvModel().eval()
+        example_inputs = [(torch.randn(1, 3, 32, 32),)]
+
+        session = export(
+            model=model,
+            example_inputs=example_inputs,
+            export_recipe=ExportRecipe.get_recipe(
+                CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY
+            ),
         )
+        self.check_fully_delegated(session)
+
+        self._compare_eager_quantized_model_outputs(session, example_inputs, atol=1e-2)
+        self._compare_eager_unquantized_model_outputs(session, model, example_inputs)
+
+    def test_pt2e_recipes_parameter_rejection(self):
+        """Test that PT2E recipes reject TorchAO-specific parameters"""
+        # PT2E recipes should reject TorchAO-specific parameters
+        pt2e_recipes = [
+            CoreMLRecipeType.PT2E_INT8_STATIC,
+            CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY,
+        ]
+        torchao_params = ["filter_fn", "group_size", "bits", "block_size"]
+
+        for recipe_type in pt2e_recipes:
+            for param in torchao_params:
+                with self.subTest(recipe=recipe_type.value, param=param):
+                    kwargs = {param: "dummy_value"}
+                    with self.assertRaises(ValueError) as cm:
+                        self.provider.create_recipe(recipe_type, **kwargs)
+                    self.assertIn("unexpected parameters", str(cm.exception).lower())
+
+    def test_filter_fn_comprehensive(self):
+        """Comprehensive test for filter_fn parameter functionality"""
+
+        def custom_filter(module, fqn):
+            return isinstance(module, nn.Linear) and "target" in fqn
+
+        # Test 1: TorchAO recipes accept filter_fn and default to None
+        torchao_recipes = [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+        ]
+
+        for recipe_type in torchao_recipes:
+            with self.subTest(f"{recipe_type.value}_default"):
+                # Test default behavior (None)
+                recipe = self.provider.create_recipe(recipe_type)
+                config = recipe.quantization_recipe.ao_quantization_configs[0]
+                self.assertIsNone(config.filter_fn)
+
+            with self.subTest(f"{recipe_type.value}_custom"):
+                # Test custom filter_fn
+                recipe = self.provider.create_recipe(
+                    recipe_type, filter_fn=custom_filter
+                )
+                config = recipe.quantization_recipe.ao_quantization_configs[0]
+                self.assertEqual(config.filter_fn, custom_filter)
+
+        # Test 2: Codebook recipe accepts filter_fn and has sensible default
+        with self.subTest("codebook_default"):
+            recipe = self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY, bits=3, block_size=[-1, 16]
+            )
+            config = recipe.quantization_recipe.ao_quantization_configs[0]
+            self.assertIsNotNone(config.filter_fn)
+
+            # Test default filter targets Linear and Embedding layers
+            linear_module = nn.Linear(10, 5)
+            embedding_module = nn.Embedding(100, 10)
+            conv_module = nn.Conv2d(3, 16, 3)
+
+            self.assertTrue(config.filter_fn(linear_module, "linear"))
+            self.assertTrue(config.filter_fn(embedding_module, "embedding"))
+            self.assertFalse(config.filter_fn(conv_module, "conv"))
+
+        with self.subTest("codebook_custom"):
+            recipe = self.provider.create_recipe(
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
+                filter_fn=custom_filter,
+                bits=3,
+                block_size=[-1, 16],
+            )
+            config = recipe.quantization_recipe.ao_quantization_configs[0]
+            self.assertEqual(config.filter_fn, custom_filter)
+
+    def test_quantization_recipe_structure(self):
+        """Test that quantization recipes have proper structure"""
+        quantization_recipes = [
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,
+            CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
+        ]
+
+        for recipe_type in quantization_recipes:
+            with self.subTest(recipe=recipe_type.value):
+                kwargs = (
+                    {"bits": 3, "block_size": [-1, 16]}
+                    if recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY
+                    else {}
+                )
+                recipe = self.provider.create_recipe(recipe_type, **kwargs)
+                self.assertIsNotNone(recipe)
+
+                # Should have quantization recipe with ao_quantization_configs
+                self.assertIsNotNone(recipe.quantization_recipe)
+                self.assertIsNotNone(recipe.quantization_recipe.ao_quantization_configs)
+                self.assertEqual(
+                    len(recipe.quantization_recipe.ao_quantization_configs), 1
+                )
+
+                # Should have lowering recipe
+                self.assertIsNotNone(recipe.lowering_recipe)
+                self.assertIsNotNone(recipe.lowering_recipe.partitioners)
+
+    def test_recipe_creation_with_defaults(self):
+        """Test that recipes work with default parameters"""
+        # Test that all recipes can be created without explicit parameters
+        all_recipes = [
+            CoreMLRecipeType.FP32,
+            CoreMLRecipeType.FP16,
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP,  # should use default group_size=32
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+            CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP,  # should use default group_size=32
+            CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,  # should use default bits=3, block_size=[-1,16]
+        ]
+
+        for recipe_type in all_recipes:
+            with self.subTest(recipe=recipe_type.value):
+                kwargs = (
+                    {"bits": 3, "block_size": [-1, 16]}
+                    if recipe_type == CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY
+                    else {}
+                )
+                recipe = self.provider.create_recipe(recipe_type, **kwargs)
+                self.assertIsNotNone(recipe)
+                self.assertEqual(recipe.name, recipe_type.value)
+
+    def test_minimum_deployment_target_validation(self):
+        """Test that minimum_deployment_target validation works correctly for quantization recipes"""
+        test_cases = [
+            (CoreMLRecipeType.PT2E_INT8_STATIC, ct.target.iOS17, {}),
+            (CoreMLRecipeType.PT2E_INT8_WEIGHT_ONLY, ct.target.iOS17, {}),
+            (
+                CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_CHANNEL,
+                ct.target.iOS18,
+                {},
+            ),
+            (CoreMLRecipeType.TORCHAO_INT4_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),
+            (
+                CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_CHANNEL,
+                ct.target.iOS18,
+                {},
+            ),
+            (CoreMLRecipeType.TORCHAO_INT8_WEIGHT_ONLY_PER_GROUP, ct.target.iOS18, {}),
+            (
+                CoreMLRecipeType.CODEBOOK_WEIGHT_ONLY,
+                ct.target.iOS18,
+                {"bits": 3, "block_size": [-1, 16]},
+            ),
+        ]
+
+        for recipe_type, min_target, kwargs in test_cases:
+            with self.subTest(recipe=recipe_type.value):
+
+                # Test 1: Providing deployment target below minimum should raise ValueError
+                too_low_target = ct.target.iOS15
+                with self.assertRaises(ValueError) as cm:
+                    self.provider.create_recipe(
+                        recipe_type, minimum_deployment_target=too_low_target, **kwargs
+                    )
+                error_msg = str(cm.exception)
+                self.assertIn(
+                    f"minimum_deployment_target must be {str(min_target)} or higher",
+                    error_msg,
+                )
+
+                # Test 2: Providing valid deployment target should work
+                valid_recipe = self.provider.create_recipe(
+                    recipe_type, minimum_deployment_target=min_target, **kwargs
+                )
+                self.assertIsNotNone(valid_recipe)
+
+                # Test 3: Not providing deployment target should default to minimum
+                default_recipe = self.provider.create_recipe(recipe_type, **kwargs)
+                self.assertIsNotNone(default_recipe)
+
+                # Test 4: Providing deployment target higher than minimum should work
+                higher_target = (
+                    ct.target.iOS18
+                    if min_target == ct.target.iOS17
+                    else ct.target.iOS18
+                )
+                higher_recipe = self.provider.create_recipe(
+                    recipe_type, minimum_deployment_target=higher_target, **kwargs
+                )
+                self.assertIsNotNone(higher_recipe)