-
Notifications
You must be signed in to change notification settings - Fork 605
Add export recipes for xnnpack (#12069) #12070
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
abhinaykukkadapu
wants to merge
1
commit into
pytorch:main
Choose a base branch
from
abhinaykukkadapu:export-D77414795
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") | ||
|
||
oncall("executorch") | ||
|
||
runtime.python_library( | ||
name = "xnnpack_recipes", | ||
srcs = [ | ||
"recipes.py", | ||
], | ||
visibility = [ | ||
"//executorch/...", | ||
"@EXECUTORCH_CLIENTS", | ||
], | ||
deps = [ | ||
"//caffe2:torch", | ||
"//executorch/exir:lib", | ||
"//executorch/export:recipe", | ||
"//executorch/backends/transforms:duplicate_dynamic_quant_chain", | ||
"//executorch/backends/xnnpack/quantizer:xnnpack_quantizer", | ||
"//executorch/backends/xnnpack/partition:xnnpack_partitioner", | ||
], | ||
) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
from functools import partial | ||
from typing import Any, Callable | ||
|
||
from executorch.backends.xnnpack.partition.config.xnnpack_config import ( | ||
ConfigPrecisionType, | ||
) | ||
|
||
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner | ||
|
||
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer import ( | ||
get_symmetric_quantization_config, | ||
XNNPACKQuantizer, | ||
) | ||
from executorch.backends.xnnpack.utils.configs import get_xnnpack_edge_compile_config | ||
from executorch.export.recipe import ExportRecipe, QuantizationRecipe | ||
from torchao.quantization.quant_api import int8_dynamic_activation_int4_weight | ||
|
||
|
||
def get_fp32_recipe() -> ExportRecipe: | ||
return ExportRecipe( | ||
name="fp32", | ||
quantization_recipe=None, | ||
partitioners=[XnnpackPartitioner()], | ||
) | ||
|
||
|
||
def get_quant_recipe(quant_recipe_name: str, is_per_channel: bool, is_dynamic: bool, is_qat:bool=False, **_kwargs: Any) -> ExportRecipe: | ||
# Create quantizer | ||
quantizer = XNNPACKQuantizer() | ||
operator_config = get_symmetric_quantization_config( | ||
is_per_channel=is_per_channel, is_dynamic=is_dynamic, is_qat=is_qat | ||
) | ||
quantizer.set_global(operator_config) | ||
|
||
# Create quantization recipe | ||
quant_recipe = QuantizationRecipe( | ||
quantizers=[quantizer], | ||
) | ||
|
||
config_precision = (ConfigPrecisionType.DYNAMIC_QUANT if is_dynamic else ConfigPrecisionType.STATIC_QUANT) | ||
|
||
# Create export recipe | ||
return ExportRecipe( | ||
name=quant_recipe_name, | ||
quantization_recipe=quant_recipe, | ||
partitioners=[XnnpackPartitioner(config_precision=config_precision)], | ||
edge_compile_config=get_xnnpack_edge_compile_config(), | ||
) | ||
|
||
|
||
def get_8a4w_config(group_size: int = 32) -> ExportRecipe: | ||
# Create quantization recipe | ||
quant_recipe = QuantizationRecipe( | ||
quantizers=None, | ||
ao_base_config=[ | ||
int8_dynamic_activation_int4_weight(group_size=group_size), | ||
], | ||
) | ||
|
||
# Create export recipe | ||
return ExportRecipe( | ||
name="8a4w_quant", | ||
quantization_recipe=quant_recipe, | ||
partitioners=[XnnpackPartitioner()], | ||
) | ||
|
||
|
||
RECIPE_MAP: dict[str, Callable[..., ExportRecipe]] = { | ||
"FP32_RECIPE": get_fp32_recipe, | ||
"QUANT_RECIPE": get_quant_recipe, | ||
"DYNAMIC_PER_CHANNEL_QUANT_RECIPE": partial(get_quant_recipe, "dynamic_per_channel_quant", is_per_channel=True, is_dynamic=True), | ||
"STATIC_PER_CHANNEL_QUANT_RECIPE": partial(get_quant_recipe, "static_per_channel_quant", is_per_channel=True, is_dynamic=False), | ||
"STATIC_PER_TENSOR_QUANT_RECIPE": partial(get_quant_recipe, "static_per_tensor_quant",is_per_channel=False, is_dynamic=False), | ||
"8A4W_ACCELERATED_RECIPE": get_8a4w_config, | ||
} | ||
|
||
|
||
def get_xnnpack_recipe(recipe_name: str, **kwargs: Any) -> ExportRecipe: | ||
assert recipe_name in RECIPE_MAP, f"Recipe {recipe_name} not found." | ||
return RECIPE_MAP[recipe_name](**kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,192 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think what would be a really good test is if we used these recipes on all of the example_models from aot_compiler: https://github.com/pytorch/executorch/blob/main/examples/xnnpack/__init__.py#L29-L48 this would be a nice test to make sure these recipes all work. |
||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import unittest | ||
|
||
import torch | ||
from executorch.backends.xnnpack import get_xnnpack_recipe | ||
from executorch.exir.schema import DelegateCall, Program | ||
from executorch.export import export | ||
from torch import nn | ||
from torch.testing._internal.common_quantization import TestHelperModules | ||
from torchvision import models | ||
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights | ||
from executorch.backends.xnnpack.test.tester import Tester | ||
from torchvision.models.segmentation import deeplabv3, deeplabv3_resnet50 # @manual | ||
|
||
|
||
class TestXnnpackRecipes(unittest.TestCase): | ||
def setUp(self) -> None: | ||
torch._dynamo.reset() | ||
super().setUp() | ||
|
||
def tearDown(self) -> None: | ||
super().tearDown() | ||
|
||
def check_fully_delegated(self, program: Program) -> None: | ||
instructions = program.execution_plan[0].chains[0].instructions | ||
assert instructions is not None | ||
self.assertEqual(len(instructions), 1) | ||
self.assertIsInstance(instructions[0].instr_args, DelegateCall) | ||
|
||
def test_basic_recipe(self) -> None: | ||
m_eager = TestHelperModules.TwoLinearModule().eval() | ||
example_inputs = [(torch.randn(9, 8),)] | ||
session = export( | ||
model=m_eager, | ||
example_inputs=example_inputs, | ||
export_recipe=get_xnnpack_recipe("FP32_RECIPE"), | ||
) | ||
self.assertTrue( | ||
torch.allclose( | ||
session.run_method("forward", example_inputs[0])[0], | ||
m_eager(*example_inputs[0]), | ||
atol=1e-1, | ||
) | ||
) | ||
self.check_fully_delegated(session.get_executorch_program()) | ||
|
||
def test_dynamic_quant_recipe(self) -> None: | ||
with torch.no_grad(): | ||
m_eager = TestHelperModules.TwoLinearModule().eval() | ||
example_inputs = [(torch.randn(9, 8),)] | ||
session = export( | ||
model=m_eager, | ||
example_inputs=example_inputs, | ||
export_recipe=get_xnnpack_recipe( | ||
"DYNAMIC_PER_CHANNEL_QUANT_RECIPE" | ||
), | ||
) | ||
self.assertTrue( | ||
torch.allclose( | ||
session.run_method("forward", example_inputs[0])[0], | ||
m_eager(*example_inputs[0]), | ||
atol=1e-1, | ||
) | ||
) | ||
self.check_fully_delegated(session.get_executorch_program()) | ||
|
||
def test_static_quant_recipe(self) -> None: | ||
with torch.no_grad(): | ||
m_eager = TestHelperModules.TwoLinearModule().eval() | ||
example_inputs = [(torch.randn(9, 8),)] | ||
session = export( | ||
model=m_eager, | ||
example_inputs=example_inputs, | ||
export_recipe=get_xnnpack_recipe( | ||
"STATIC_PER_CHANNEL_QUANT_RECIPE" | ||
), | ||
) | ||
self.assertTrue( | ||
torch.allclose( | ||
session.run_method("forward", example_inputs[0])[0], | ||
m_eager(*example_inputs[0]), | ||
atol=1e-1, | ||
) | ||
) | ||
self.check_fully_delegated(session.get_executorch_program()) | ||
|
||
def test_8a4w_recipe(self) -> None: | ||
class SimpleLinearModel(nn.Module): | ||
def __init__(self) -> None: | ||
super(SimpleLinearModel, self).__init__() | ||
self.layer1 = nn.Linear(32, 2) | ||
|
||
def forward(self, x) -> torch.Tensor: | ||
x = self.layer1(x) | ||
return x | ||
|
||
model = SimpleLinearModel() | ||
example_inputs = [(torch.randn(1, 32),)] | ||
session = export( | ||
model=model, | ||
example_inputs=example_inputs, | ||
export_recipe=get_xnnpack_recipe( | ||
"8A4W_ACCELERATED_RECIPE", group_size=32 | ||
), | ||
) | ||
self.assertTrue( | ||
torch.allclose( | ||
session.run_method("forward", example_inputs[0])[0], | ||
model(*example_inputs[0]), | ||
atol=1e-1, | ||
) | ||
) | ||
self.check_fully_delegated(session.get_executorch_program()) | ||
|
||
def test_mv3_model(self) -> None: | ||
mv3 = models.mobilenetv3.mobilenet_v3_small(pretrained=True) | ||
mv3 = mv3.eval() | ||
model_inputs = [(torch.randn(1, 3, 224, 224),)] | ||
self.assertTrue(hasattr(mv3, "forward")) | ||
dynamic_shapes =({2: torch.export.Dim("height", min=224, max=455), 3: torch.export.Dim("width", min=224, max=455)},) | ||
session = export( | ||
model=mv3, | ||
example_inputs=model_inputs, | ||
dynamic_shapes=dynamic_shapes, | ||
export_recipe=get_xnnpack_recipe( | ||
"STATIC_PER_CHANNEL_QUANT_RECIPE" | ||
), | ||
) | ||
|
||
Tester._assert_outputs_equal( | ||
session.run_method("forward", model_inputs[0])[0], | ||
mv3(*model_inputs[0]), | ||
atol=1e-3, | ||
) | ||
|
||
def test_mv2_model_with_static_quant_recipe(self) -> None: | ||
mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights) | ||
mv2 = mv2.eval() | ||
model_inputs = [(torch.randn(1, 3, 224, 224),)] | ||
self.assertTrue(hasattr(mv2, "forward")) | ||
dynamic_shapes =({2: torch.export.Dim("height", min=224, max=455), 3: torch.export.Dim("width", min=224, max=455)},) | ||
session = export( | ||
model=mv2, | ||
example_inputs=model_inputs, | ||
dynamic_shapes=dynamic_shapes, | ||
export_recipe=get_xnnpack_recipe( | ||
"STATIC_PER_CHANNEL_QUANT_RECIPE" | ||
), | ||
) | ||
|
||
Tester._assert_outputs_equal( | ||
session.run_method("forward", model_inputs[0])[0], | ||
mv2(*model_inputs[0]), | ||
atol=1e-3, | ||
) | ||
|
||
def test_dl3_with_recipe(self) -> None: | ||
class DL3Wrapper(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.m = deeplabv3_resnet50( | ||
weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT | ||
) | ||
|
||
def forward(self, *args): | ||
return self.m(*args)["out"] | ||
|
||
dl3 = DL3Wrapper() | ||
dl3 = dl3.eval() | ||
model_inputs = [(torch.randn(1, 3, 224, 224),)] | ||
self.assertTrue(hasattr(dl3, "forward")) | ||
session = export( | ||
model=dl3, | ||
example_inputs=model_inputs, | ||
export_recipe=get_xnnpack_recipe( | ||
"STATIC_PER_CHANNEL_QUANT_RECIPE" | ||
), | ||
) | ||
|
||
Tester._assert_outputs_equal( | ||
session.run_method("forward", model_inputs[0])[0], | ||
dl3(*model_inputs[0]), | ||
atol=1e-3, | ||
) | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.