From 57e241d3dc7855c9fcc9642452191cd37be225d9 Mon Sep 17 00:00:00 2001 From: Justin Ngo Date: Tue, 22 Apr 2025 00:18:57 +0000 Subject: [PATCH] [TOSA] Add TosaLayerwiseConstantFoldPass and TosaReduceTransposes passes Add the following passes to TorchBackendToTosaBackendPipeline: - TosaLayerwiseConstantFoldPass: fold full-layer operations on TOSA consts - TosaReduceTransposes: remove unnecessary TOSA transposes to reduce data movements Signed-off-by: Justin Ngo Change-Id: I44636a7392a571e57c2c3a2c0316835d2d2ca938 --- lib/Dialect/TorchConversion/Transforms/Passes.cpp | 6 ++++++ projects/pt1/e2e_testing/xfail_sets.py | 9 +++++++++ 2 files changed, 15 insertions(+) diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index ee75c9678b72..3ffbdb642db9 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -117,6 +117,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( const TorchConversion::TosaBackendPipelineOptions &options) { pm.addNestedPass( createConvertTorchToTosaPass(options.requireFullTosaConversion)); + // Fold full-layer operations on TOSA constants + pm.addNestedPass(createTosaLayerwiseConstantFoldPass()); + + // Perform transpose reductions for avoidable data movements + pm.addNestedPass(createTosaReduceTransposes()); + // Perform rank broadcasting so TosaToLinalg pass works pm.addNestedPass(createTosaMakeBroadcastablePass()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index b42339ae9a25..66ccc39907d6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1731,6 +1731,15 @@ "HBC_basic", # 1D inputs cause generated tosa.negate ops to crash downstream "NllLossModule_1D_basic", + # BertModule is not crashing, but is timing out due to TosaLayerwiseConstantFoldPass: + # Exception ignored on calling ctypes callback function: .consume_return_funcs at 0x765783f12c20> + # Traceback (most recent call last): + # File "torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py", line 101, in consume_return_funcs + # def consume_return_funcs(*args): + # File "torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 316, in handle_timeout + # raise TimeoutError(self.error_message) + # TimeoutError: Timeout + "BertModule_basic", } # Write the TOSA set as a "passing" set as it is very early in development