diff --git a/backends/arm/_passes/decompose_sqrt_pass.py b/backends/arm/_passes/decompose_sqrt_pass.py index d4a678affea..547d0091e90 100644 --- a/backends/arm/_passes/decompose_sqrt_pass.py +++ b/backends/arm/_passes/decompose_sqrt_pass.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe +from typing import Tuple, Union + import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass @@ -15,7 +17,7 @@ ) -def get_sqrt_decomposition(op) -> tuple: +def get_sqrt_decomposition(op) -> Union[Tuple, torch._ops.OpOverload]: # TODO : "MLETORCH-863 : Replace current sqrt -> pow.Tensor_Scalar workaround with pow.Tensor_Tensor" if op in edge_sqrt_ops: return exir_ops.edge.aten.pow.Tensor_Scalar diff --git a/backends/arm/_passes/replace_scalar_with_tensor_pass.py b/backends/arm/_passes/replace_scalar_with_tensor_pass.py index f9ad3cfaee8..1e8b2d6b651 100644 --- a/backends/arm/_passes/replace_scalar_with_tensor_pass.py +++ b/backends/arm/_passes/replace_scalar_with_tensor_pass.py @@ -6,7 +6,7 @@ # pyre-unsafe -from typing import Dict +from typing import Dict, Union import torch from executorch.backends.transforms.replace_scalar_with_tensor import ( @@ -18,7 +18,10 @@ # Operators that are included for both TOSA profiles -_common_ops: Dict[EdgeOpOverload, EdgeOpOverload] = { +_common_ops: Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], +] = { exir_ops.edge.aten.add.Scalar: exir_ops.edge.aten.add.Tensor, exir_ops.edge.aten.sub.Scalar: exir_ops.edge.aten.sub.Tensor, exir_ops.edge.aten.mul.Scalar: exir_ops.edge.aten.mul.Tensor, diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py index 21eb325b646..8ce05a3d4d4 100644 --- a/backends/transforms/replace_scalar_with_tensor.py +++ b/backends/transforms/replace_scalar_with_tensor.py @@ -5,7 +5,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Dict, Optional +from typing import Dict, Optional, Union import torch from executorch.exir.dialects._ops import ops as exir_ops @@ -32,7 +32,12 @@ class ReplaceScalarWithTensorArgPass(ExportPass): def __init__( self, - scalar_to_tensor_ops: Optional[Dict[EdgeOpOverload, EdgeOpOverload]] = None, + scalar_to_tensor_ops: Optional[ + Dict[ + Union[EdgeOpOverload, torch._ops.OpOverload], + Union[EdgeOpOverload, torch._ops.OpOverload], + ] + ] = None, ): if scalar_to_tensor_ops is not None: self.scalar_to_tensor_ops = scalar_to_tensor_ops