Add export recipes for xnnpack (#12069) #12070

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions backends/xnnpack/TARGETS
@@ -38,5 +38,6 @@ runtime.python_library(
":xnnpack_preprocess",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
"//executorch/backends/xnnpack/utils:xnnpack_utils",
"//executorch/backends/xnnpack/recipes:xnnpack_recipes"
],
)
3 changes: 2 additions & 1 deletion backends/xnnpack/__init__.py
@@ -9,6 +9,7 @@
XnnpackDynamicallyQuantizedPartitioner,
XnnpackPartitioner,
)
from .recipes.recipes import get_xnnpack_recipe

# Exposed Configs in XNNPACK Package
from .utils.configs import (
@@ -23,12 +24,12 @@
# XNNPACK Backend
from .xnnpack_preprocess import XnnpackBackend


__all__ = [
"XnnpackDynamicallyQuantizedPartitioner",
"XnnpackPartitioner",
"XnnpackBackend",
"capture_graph_for_xnnpack",
"get_xnnpack_recipe",
"get_xnnpack_capture_config",
"get_xnnpack_edge_compile_config",
"get_xnnpack_executorch_backend_config",
22 changes: 22 additions & 0 deletions backends/xnnpack/recipes/TARGETS
@@ -0,0 +1,22 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")

runtime.python_library(
name = "xnnpack_recipes",
srcs = [
"recipes.py",
],
visibility = [
"//executorch/...",
"@EXECUTORCH_CLIENTS",
],
deps = [
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/export:recipe",
"//executorch/backends/transforms:duplicate_dynamic_quant_chain",
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer",
"//executorch/backends/xnnpack/partition:xnnpack_partitioner",
],
)
87 changes: 87 additions & 0 deletions backends/xnnpack/recipes/recipes.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from functools import partial
from typing import Any, Callable

from executorch.backends.xnnpack.partition.config.xnnpack_config import (
ConfigPrecisionType,
)

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import (
get_symmetric_quantization_config,
XNNPACKQuantizer,
)
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config
from executorch.export.recipe import ExportRecipe, QuantizationRecipe
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight


def get_fp32_recipe() -> ExportRecipe:
return ExportRecipe(
name="fp32",
quantization_recipe=None,
partitioners=[XnnpackPartitioner()],
)


def get_quant_recipe(
    quant_recipe_name: str,
    is_per_channel: bool,
    is_dynamic: bool,
    is_qat: bool = False,
    **_kwargs: Any,
) -> ExportRecipe:
# Create quantizer
quantizer = XNNPACKQuantizer()
operator_config = get_symmetric_quantization_config(
is_per_channel=is_per_channel, is_dynamic=is_dynamic, is_qat=is_qat
)
quantizer.set_global(operator_config)

# Create quantization recipe
quant_recipe = QuantizationRecipe(
quantizers=[quantizer],
)

    config_precision = (
        ConfigPrecisionType.DYNAMIC_QUANT
        if is_dynamic
        else ConfigPrecisionType.STATIC_QUANT
    )

# Create export recipe
return ExportRecipe(
name=quant_recipe_name,
quantization_recipe=quant_recipe,
partitioners=[XnnpackPartitioner(config_precision=config_precision)],
edge_compile_config=get_xnnpack_edge_compile_config(),
)


def get_8a4w_config(group_size: int = 32) -> ExportRecipe:
# Create quantization recipe
quant_recipe = QuantizationRecipe(
quantizers=None,
ao_base_config=[
int8_dynamic_activation_int4_weight(group_size=group_size),
],
)

# Create export recipe
return ExportRecipe(
name="8a4w_quant",
quantization_recipe=quant_recipe,
partitioners=[XnnpackPartitioner()],
)


RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = {
"FP32_RECIPE": get_fp32_recipe,
"QUANT_RECIPE": get_quant_recipe,
"DYNAMIC_PER_CHANNEL_QUANT_RECIPE": partial(get_quant_recipe, "dynamic_per_channel_quant", is_per_channel=True, is_dynamic=True),
"STATIC_PER_CHANNEL_QUANT_RECIPE": partial(get_quant_recipe, "static_per_channel_quant", is_per_channel=True, is_dynamic=False),
"STATIC_PER_TENSOR_QUANT_RECIPE": partial(get_quant_recipe, "static_per_tensor_quant",is_per_channel=False, is_dynamic=False),
"8A4W_ACCELERATED_RECIPE": get_8a4w_config,
}


def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe:
assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found."
return RECIPE_MAP[recipe_name](**kwargs)
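
For orientation, a minimal usage sketch of the new entry point. The toy model and inputs below are hypothetical; the export(...) call pattern is the one exercised by the tests added further down in this PR:

```python
# Minimal sketch (hypothetical toy model); mirrors the call pattern
# used in backends/xnnpack/test/recipes/test_xnnpack_recipes.py below.
import torch

from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.export import export

model = torch.nn.Linear(8, 4).eval()  # hypothetical stand-in model
example_inputs = [(torch.randn(2, 8),)]

session = export(
    model=model,
    example_inputs=example_inputs,
    export_recipe=get_xnnpack_recipe("DYNAMIC_PER_CHANNEL_QUANT_RECIPE"),
)
outputs = session.run_method("forward", example_inputs[0])
```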
13 changes: 13 additions & 0 deletions backends/xnnpack/test/TARGETS
@@ -94,3 +94,16 @@ runtime.python_test(
"libtorch",
],
)

runtime.python_test(
name = "test_xnnpack_recipes",
srcs = glob([
"recipes/*.py",
]),
deps = [
"//executorch/backends/xnnpack:xnnpack_delegate",
"//executorch/export:lib",
"//pytorch/vision:torchvision", # @manual,
"//executorch/backends/xnnpack/test/tester:tester",
],
)
192 changes: 192 additions & 0 deletions backends/xnnpack/test/recipes/test_xnnpack_recipes.py
@@ -0,0 +1,192 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
Contributor review comment:
I think what would be a really good test is if we used these recipes on all of the example_models from aot_compiler:

https://github.com/pytorch/executorch/blob/main/examples/xnnpack/__init__.py#L29-L48

This would be a nice test to make sure these recipes all work.
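
A hedged sketch of that suggestion; the registry import and the model-loading helper are assumptions, not part of this PR:

```python
# Sketch only: MODEL_NAME_TO_OPTIONS is the registry linked above
# (assumed importable from the test's deps); load_example_model is a
# hypothetical helper standing in for the EagerModelFactory wiring
# that aot_compiler uses.
from executorch.examples.xnnpack import MODEL_NAME_TO_OPTIONS  # assumed importable

def test_all_example_models(self) -> None:
    for model_name in MODEL_NAME_TO_OPTIONS:
        model, example_inputs = load_example_model(model_name)  # hypothetical
        session = export(
            model=model,
            example_inputs=[example_inputs],
            export_recipe=get_xnnpack_recipe("STATIC_PER_CHANNEL_QUANT_RECIPE"),
        )
        self.check_fully_delegated(session.get_executorch_program())
```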

# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest

import torch
from executorch.backends.xnnpack import get_xnnpack_recipe
from executorch.backends.xnnpack.test.tester import Tester
from executorch.exir.schema import DelegateCall, Program
from executorch.export import export
from torch import nn
from torch.testing._internal.common_quantization import TestHelperModules
from torchvision import models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50  # @manual


class TestXnnpackRecipes(unittest.TestCase):
def setUp(self) -> None:
torch._dynamo.reset()
super().setUp()

def tearDown(self) -> None:
super().tearDown()

def check_fully_delegated(self, program: Program) -> None:
instructions = program.execution_plan[0].chains[0].instructions
assert instructions is not None
self.assertEqual(len(instructions), 1)
self.assertIsInstance(instructions[0].instr_args, DelegateCall)

def test_basic_recipe(self) -> None:
m_eager = TestHelperModules.TwoLinearModule().eval()
example_inputs = [(torch.randn(9, 8),)]
session = export(
model=m_eager,
example_inputs=example_inputs,
export_recipe=get_xnnpack_recipe("FP32_RECIPE"),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
)
self.check_fully_delegated(session.get_executorch_program())

def test_dynamic_quant_recipe(self) -> None:
with torch.no_grad():
m_eager = TestHelperModules.TwoLinearModule().eval()
example_inputs = [(torch.randn(9, 8),)]
session = export(
model=m_eager,
example_inputs=example_inputs,
export_recipe=get_xnnpack_recipe(
"DYNAMIC_PER_CHANNEL_QUANT_RECIPE"
),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
)
self.check_fully_delegated(session.get_executorch_program())

def test_static_quant_recipe(self) -> None:
with torch.no_grad():
m_eager = TestHelperModules.TwoLinearModule().eval()
example_inputs = [(torch.randn(9, 8),)]
session = export(
model=m_eager,
example_inputs=example_inputs,
export_recipe=get_xnnpack_recipe(
"STATIC_PER_CHANNEL_QUANT_RECIPE"
),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
m_eager(*example_inputs[0]),
atol=1e-1,
)
)
self.check_fully_delegated(session.get_executorch_program())

def test_8a4w_recipe(self) -> None:
class SimpleLinearModel(nn.Module):
def __init__(self) -> None:
super(SimpleLinearModel, self).__init__()
self.layer1 = nn.Linear(32, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.layer1(x)
return x

model = SimpleLinearModel()
example_inputs = [(torch.randn(1, 32),)]
session = export(
model=model,
example_inputs=example_inputs,
export_recipe=get_xnnpack_recipe(
"8A4W_ACCELERATED_RECIPE", group_size=32
),
)
self.assertTrue(
torch.allclose(
session.run_method("forward", example_inputs[0])[0],
model(*example_inputs[0]),
atol=1e-1,
)
)
self.check_fully_delegated(session.get_executorch_program())

def test_mv3_model(self) -> None:
mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True)
mv3 = mv3.eval()
model_inputs = [(torch.randn(1, 3, 224, 224),)]
self.assertTrue(hasattr(mv3, "forward"))
        dynamic_shapes = (
            {
                2: torch.export.Dim("height", min=224, max=455),
                3: torch.export.Dim("width", min=224, max=455),
            },
        )
session = export(
model=mv3,
example_inputs=model_inputs,
dynamic_shapes=dynamic_shapes,
export_recipe=get_xnnpack_recipe(
"STATIC_PER_CHANNEL_QUANT_RECIPE"
),
)

Tester._assert_outputs_equal(
session.run_method("forward", model_inputs[0])[0],
mv3(*model_inputs[0]),
atol=1e-3,
)

def test_mv2_model_with_static_quant_recipe(self) -> None:
        mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
mv2 = mv2.eval()
model_inputs = [(torch.randn(1, 3, 224, 224),)]
self.assertTrue(hasattr(mv2, "forward"))
        dynamic_shapes = (
            {
                2: torch.export.Dim("height", min=224, max=455),
                3: torch.export.Dim("width", min=224, max=455),
            },
        )
session = export(
model=mv2,
example_inputs=model_inputs,
dynamic_shapes=dynamic_shapes,
export_recipe=get_xnnpack_recipe(
"STATIC_PER_CHANNEL_QUANT_RECIPE"
),
)

Tester._assert_outputs_equal(
session.run_method("forward", model_inputs[0])[0],
mv2(*model_inputs[0]),
atol=1e-3,
)

def test_dl3_with_recipe(self) -> None:
class DL3Wrapper(torch.nn.Module):
            def __init__(self) -> None:
super().__init__()
self.m = deeplabv3_resnet50(
weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
)

            def forward(self, *args: torch.Tensor) -> torch.Tensor:
return self.m(*args)["out"]

dl3 = DL3Wrapper()
dl3 = dl3.eval()
model_inputs = [(torch.randn(1, 3, 224, 224),)]
self.assertTrue(hasattr(dl3, "forward"))
session = export(
model=dl3,
example_inputs=model_inputs,
export_recipe=get_xnnpack_recipe(
"STATIC_PER_CHANNEL_QUANT_RECIPE"
),
)

Tester._assert_outputs_equal(
session.run_method("forward", model_inputs[0])[0],
dl3(*model_inputs[0]),
atol=1e-3,
)

15 changes: 8 additions & 7 deletions export/export.py
@@ -25,6 +25,9 @@

from .recipe import ExportRecipe

from torch._export.pass_base import PassType
from executorch.exir.program._program import _transform


class Stage(ABC):
"""
@@ -95,9 +98,7 @@ class ExportStage(Stage):

def __init__(
self,
pre_edge_transform_passes: Optional[
Callable[[ExportedProgram], ExportedProgram]
] = None,
pre_edge_transform_passes: Optional[List[PassType]] = None,
) -> None:
self._exported_program: Dict[str, ExportedProgram] = {}
self._pre_edge_transform_passes = pre_edge_transform_passes
@@ -153,10 +154,10 @@ def run(
)

# Apply pre-edge transform passes if available
if self._pre_edge_transform_passes is not None:
for pre_edge_transform_pass in self._pre_edge_transform_passes:
self._exported_program[method_name] = pre_edge_transform_pass(
self._exported_program[method_name]
        if pre_edge_transform_passes := self._pre_edge_transform_passes or []:
for pass_ in pre_edge_transform_passes:
self._exported_program[method_name] = _transform(
self._exported_program[method_name], pass_
)

def get_artifacts(self) -> Dict[str, ExportedProgram]:
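
For reference, a small sketch of a pass that would fit the reworked hook. The no-op pass and stage wiring are illustrative only, assuming PassType (torch._export.pass_base) is a callable from GraphModule to PassResult:

```python
# Hedged sketch: a pass compatible with the new List[PassType]
# parameter of ExportStage. PassResult comes from
# torch.fx.passes.infra.pass_base; nothing here is part of this PR.
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult


def noop_pass(gm: GraphModule) -> PassResult:
    # A real pre-edge pass would rewrite gm.graph before returning.
    return PassResult(gm, modified=False)


stage = ExportStage(pre_edge_transform_passes=[noop_pass])
```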