diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml index 37fbaf9588..bfb5102c51 100644 --- a/.github/workflows/check-pl-compat.yaml +++ b/.github/workflows/check-pl-compat.yaml @@ -39,7 +39,7 @@ jobs: git checkout $(git tag | sort -V | tail -1) - if: ${{ inputs.catalyst == 'release-candidate' }} run: | - git checkout v0.8.0-rc + git checkout v0.8.1-rc - name: Install deps run: | diff --git a/doc/dev/jax_integration.rst b/doc/dev/jax_integration.rst index cbbbd733b6..4d427977cb 100644 --- a/doc/dev/jax_integration.rst +++ b/doc/dev/jax_integration.rst @@ -88,7 +88,6 @@ that doesn't work with Catalyst includes: - ``jax.numpy.polyfit`` - ``jax.numpy.fft`` -- ``jax.scipy.linalg`` - ``jax.numpy.ndarray.at[index]`` when ``index`` corresponds to all array indices. diff --git a/doc/releases/changelog-0.8.1.md b/doc/releases/changelog-0.8.1.md index f8b7d5b94c..62ac8ac216 100644 --- a/doc/releases/changelog-0.8.1.md +++ b/doc/releases/changelog-0.8.1.md @@ -2,14 +2,91 @@

New features

+* The `catalyst.mitigate_with_zne` error mitigation compilation pass now supports
+  local folding of gates, in addition to the existing global folding method.
+  [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006)
+  [(#1129)](https://github.com/PennyLaneAI/catalyst/pull/1129)
+
+  While global folding applies the scale factor by forming the inverse of the
+  entire quantum circuit (without measurements) and repeating the circuit with
+  its inverse, local folding instead inserts per-gate folding sequences directly
+  in place of each gate in the original circuit.
+
+  For example,
+
+  ```python
+  import pennylane as qml
+  from catalyst import qjit, mitigate_with_zne
+  from pennylane.transforms import exponential_extrapolate
+
+  dev = qml.device("lightning.qubit", wires=4, shots=5)
+
+  @qml.qnode(dev)
+  def circuit():
+      qml.Hadamard(wires=0)
+      qml.CNOT(wires=[0, 1])
+      return qml.expval(qml.PauliY(wires=0))
+
+  @qjit(keep_intermediate=True)
+  def mitigated_circuit():
+      s = [1, 3, 5]  # scale factors must be positive odd integers in this release
+      return mitigate_with_zne(
+          circuit,
+          scale_factors=s,
+          extrapolate=exponential_extrapolate,
+          folding="local-all",  # "local-all" folds every gate; "global" (the default) folds the whole circuit
+      )()
+  ```
+
+  ```pycon
+  >>> circuit()
+  >>> mitigated_circuit()
+  ```
+
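+  Concretely, each odd scale factor `s` maps to `n = (s - 1) / 2` folds (see the
+  `mitigate_with_zne` changes later in this diff, where `num_folds` is computed this
+  way): global folding repeats the entire circuit unitary `U` as `U (U† U)^n`, while
+  `local-all` folding applies the analogous replacement `G → G (G† G)^n` to every gate `G`.
+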

Improvements

+* Fixes an issue where certain JAX linear algebra functions from `jax.scipy.linalg` gave incorrect + results when invoked from within a qjit block, and adds full support for other `jax.scipy.linalg` + functions. + [(#1097)](https://github.com/PennyLaneAI/catalyst/pull/1097) + + The supported linear algebra functions include, but are not limited to: + + - [`jax.scipy.linalg.cholesky`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html) + - [`jax.scipy.linalg.expm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html) + - [`jax.scipy.linalg.funm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html) + - [`jax.scipy.linalg.hessenberg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html) + - [`jax.scipy.linalg.lu`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html) + - [`jax.scipy.linalg.lu_solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html) + - [`jax.scipy.linalg.polar`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html) + - [`jax.scipy.linalg.qr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html) + - [`jax.scipy.linalg.schur`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html) + - [`jax.scipy.linalg.solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html) + - [`jax.scipy.linalg.sqrtm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html) + - [`jax.scipy.linalg.svd`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html) +
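+  For example, `jax.scipy.linalg.solve` can now be called directly inside a
+  qjit-compiled function (a minimal sketch):
+
+  ```python
+  import jax.scipy.linalg
+  from catalyst import qjit
+
+  @qjit
+  def solve_linear_system(A, b):
+      # Runs natively inside the compiled program; previously a call like this
+      # had to be offloaded with catalyst.accelerate() to get correct results.
+      return jax.scipy.linalg.solve(A, b)
+  ```
+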

Breaking changes

-
+* The `scale_factors` argument of the `mitigate_with_zne` function now follows the
+  standard definition used in the zero-noise extrapolation literature: it must be a
+  list of positive odd integers, since fractional scale factors are not supported.
+  [(#1120)](https://github.com/PennyLaneAI/catalyst/pull/1120)

Deprecations


Bug fixes

+* Functions that call the `gather_p` primitive (such as `jax.scipy.linalg.expm`)
+  can now be used in multiple qjits in a single program.
+  [(#1096)](https://github.com/PennyLaneAI/catalyst/pull/1096)
+
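+  For instance (a minimal sketch), the same `jax.scipy.linalg.expm` call can now
+  appear in two separately qjit-compiled functions within one program:
+
+  ```python
+  import jax.scipy.linalg
+  from catalyst import qjit
+
+  @qjit
+  def f(A):
+      return jax.scipy.linalg.expm(A)
+
+  @qjit
+  def g(A):
+      # A second qjit-compiled function using the same gather-based primitive
+      return jax.scipy.linalg.expm(2.0 * A)
+  ```
+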

Contributors

This release contains contributions from (in alphabetical order):
+
+Joey Carter,
+Alessandro Cosentino,
+Paul Haochen Wang,
+David Ittah,
+Romain Moyard,
+Daniel Strano,
+Raul Torres.
diff --git a/frontend/catalyst/_version.py b/frontend/catalyst/_version.py
index 3a1257d76c..13fef47bd0 100644
--- a/frontend/catalyst/_version.py
+++ b/frontend/catalyst/_version.py
@@ -16,4 +16,4 @@
 Version number (major.minor.patch[-label])
 """
-__version__ = "0.8.0"
+__version__ = "0.8.1"
diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py
index 84c1982548..c3b67b11d7 100644
--- a/frontend/catalyst/api_extensions/error_mitigation.py
+++ b/frontend/catalyst/api_extensions/error_mitigation.py
@@ -30,6 +30,10 @@
 from catalyst.jax_primitives import Folding, zne_p
 
 
+def _is_odd_positive(numbers_list):
+    return all(isinstance(i, int) and i > 0 and i % 2 != 0 for i in numbers_list)
+
+
 ## API ##
 def mitigate_with_zne(
     fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="global"
@@ -47,7 +51,7 @@
 
     Args:
         fn (qml.QNode): the circuit to be mitigated.
-        scale_factors (array[int]): the range of noise scale factors used.
+        scale_factors (list[int]): the range of noise scale factors used.
         extrapolate (Callable): A qjit-compatible function taking two sequences as arguments
             (scale factors, and results), and returning a float by performing a fitting procedure.
             By default, perfect polynomial fitting :func:`~.polynomial_extrapolate` will be used,
@@ -56,6 +60,7 @@
             function.
         folding (str): Unitary folding technique to be used to scale the circuit. Possible values:
             - global: the global unitary of the input circuit is folded
+            - local-all: per-gate folding sequences replace original gates in-place in the circuit
 
     Returns:
         Callable: A callable object that computes the mitigated value of the wrapped
        :class:`~.QNode`
@@ -113,10 +118,11 @@ def workflow(weights, s):
         return zne_circuit(weights)
 
    >>> weights = jnp.ones([3, 2, 3])
-   >>> scale_factors = jnp.array([1, 2, 3])
+   >>> scale_factors = [1, 3, 5]
    >>> workflow(weights, scale_factors)
    Array(-0.19946598, dtype=float64)
    """
+
    kwargs = copy.copy(locals())
    kwargs.pop("fn")
 
@@ -128,7 +134,12 @@
     elif extrapolate_kwargs is not None:
         extrapolate = functools.partial(extrapolate, **extrapolate_kwargs)
 
-    return ZNE(fn, scale_factors, extrapolate, folding)
+    if not _is_odd_positive(scale_factors):
+        raise ValueError(f"The scale factors must be positive odd integers: {scale_factors}")
+
+    num_folds = jnp.array([jnp.floor((s - 1) / 2) for s in scale_factors], dtype=int)
+
+    return ZNE(fn, num_folds, extrapolate, folding)
 
 
 ## IMPL ##
@@ -147,7 +158,7 @@ class ZNE:
     def __init__(
         self,
         fn: Callable,
-        scale_factors: jnp.ndarray,
+        num_folds: jnp.ndarray,
         extrapolate: Callable[[Sequence[float], Sequence[float]], float],
         folding: str,
     ):
@@ -155,7 +166,7 @@ def __init__(
             raise TypeError(f"A QNode is expected, got the classical function {fn}")
         self.fn = fn
         self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
-        self.scale_factors = scale_factors
+        self.num_folds = num_folds
         self.extrapolate = extrapolate
         self.folding = folding
 
@@ -175,14 +186,12 @@ def __call__(self, *args, **kwargs):
         except ValueError as e:
             raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e
         # TODO: remove the following check once #755 is completed
-        if folding != Folding.GLOBAL:
+        if folding == Folding.RANDOM:
             raise
NotImplementedError(f"Folding type {folding.value} is being developed") - results = zne_p.bind( - *args_data, self.scale_factors, folding=folding, jaxpr=jaxpr, fn=self.fn - ) - float_scale_factors = jnp.array(self.scale_factors, dtype=float) - results = self.extrapolate(float_scale_factors, results[0]) + results = zne_p.bind(*args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=self.fn) + float_num_folds = jnp.array(self.num_folds, dtype=float) + results = self.extrapolate(float_num_folds, results[0]) # Single measurement if results.shape == (): return results diff --git a/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py b/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py deleted file mode 100644 index 5029e4d508..0000000000 --- a/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2022-2024 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This module contains warnings for using jax.scipy.linalg functions inside qjit. -Due to improperly linked lapack symbols, occasionally these functions give wrong -numerical results when used in a qjit context. -As for now, we warn users to wrap all of these with a catalyst.accelerate() callback. -This patch should be removed when we have proper linkage to lapack. -See: - https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax - https://github.com/PennyLaneAI/catalyst/issues/753 - https://github.com/PennyLaneAI/catalyst/issues/1071 -""" - -import warnings - -import jax - -from catalyst.tracing.contexts import AccelerateContext - - -class JaxLinalgWarner: - def __init__(self, fn): - self.fn = fn - - def __call__(self, *args, **kwargs): - if not AccelerateContext.am_inside_accelerate(): - warnings.warn( - f""" - jax.scipy.linalg.{self.fn.__name__} occasionally gives wrong numerical results - when used within a qjit-compiled function. - See https://github.com/PennyLaneAI/catalyst/issues/1071. 
- In the meantime, we recommend catalyst.accelerate to call - the underlying {self.fn.__name__} function directly: - - @qjit - def f(A): - return catalyst.accelerate(jax.scipy.linalg.{self.fn.__name__})(A) - - See https://docs.pennylane.ai/projects/catalyst/en/latest/code/api/catalyst.accelerate.html - """ - ) - return (self.fn)(*args, **kwargs) diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index 2ad97c4e04..f2c2e93cde 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -20,11 +20,14 @@ import jax from jax._src.lax.lax import _nary_lower_hlo from jax._src.lax.slicing import ( + _argnum_weak_type, + _gather_dtype_rule, _gather_shape_computation, _is_sorted, _no_duplicate_dims, _rank, _sorted_dims_in_range, + standard_primitive, ) from jax._src.lib.mlir.dialects import hlo from jax.core import AbstractValue, Tracer, concrete_aval @@ -35,6 +38,7 @@ "_gather_shape_rule_dynamic", "_sin_lowering2", "_cos_lowering2", + "gather2_p", ) @@ -186,6 +190,16 @@ def _gather_shape_rule_dynamic( return _gather_shape_computation(indices, dimension_numbers, slice_sizes) +# TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is +# applied. +gather2_p = standard_primitive( + _gather_shape_rule_dynamic, + _gather_dtype_rule, + "gather", + weak_type_rule=_argnum_weak_type(0), +) + + def _sin_lowering2(ctx, x): """Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28""" return _nary_lower_hlo(hlo.sine, ctx, x) diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index 97db7386ab..01fbcefd04 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -45,12 +45,7 @@ ) from jax._src.lax.control_flow import _initial_style_jaxpr from jax._src.lax.lax import _abstractify, cos_p, sin_p -from jax._src.lax.slicing import ( - _argnum_weak_type, - _gather_dtype_rule, - _gather_lower, - standard_primitive, -) +from jax._src.lax.slicing import _gather_lower from jax._src.linear_util import annotate from jax._src.pjit import _extract_implicit_args, _flat_axes_specs from jax._src.source_info_util import current as jax_current @@ -99,8 +94,8 @@ from catalyst.jax_extras.patches import ( _cos_lowering2, - _gather_shape_rule_dynamic, _sin_lowering2, + gather2_p, get_aval2, ) from catalyst.logging import debug_logger @@ -514,14 +509,6 @@ def abstractify(args, kwargs): in_type = infer_lambda_input_type(axes_specs, flat_args) return in_type, in_tree - # TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is - # applied. 
- gather2_p = standard_primitive( - _gather_shape_rule_dynamic, - _gather_dtype_rule, - "gather", - weak_type_rule=_argnum_weak_type(0), - ) register_lowering(gather2_p, _gather_lower) # TBD diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index 649c4d4a75..ec38c0551e 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -227,8 +227,8 @@ class Folding(Enum): """ GLOBAL = "global" - RANDOM = "random" - ALL = "all" + RANDOM = "local-random" + ALL = "local-all" ############## @@ -930,7 +930,7 @@ def _folding_attribute(ctx, folding): ctx = ctx.module_context.context return ir.OpaqueAttr.get( "mitigation", - ("folding " + Folding(folding).value).encode("utf-8"), + ("folding " + Folding(folding).name.lower()).encode("utf-8"), ir.NoneType.get(ctx), ctx, ) @@ -950,13 +950,13 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): symbol_name = func_op.name.value output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) - scale_factors = args[-1] + num_folds = args[-1] return ZneOp( flat_output_types, ir.FlatSymbolRefAttr.get(symbol_name), mlir.flatten_lowering_ir_args(args[0:-1]), _folding_attribute(ctx, folding), - scale_factors, + num_folds, ).results diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py index 88209b4d8c..7b2ae5d8a8 100644 --- a/frontend/catalyst/jit.py +++ b/frontend/catalyst/jit.py @@ -35,7 +35,6 @@ from catalyst.compiled_functions import CompilationCache, CompiledFunction from catalyst.compiler import CompileOptions, Compiler from catalyst.debug.instruments import instrument -from catalyst.jax_extras.jax_scipy_linalg_warnings import JaxLinalgWarner from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr from catalyst.logging import debug_logger, debug_logger_init from catalyst.passes import _inject_transform_named_sequence @@ -590,15 +589,6 @@ def closure(qnode, *args, **kwargs): with Patcher( (qml.QNode, "__call__", closure), - # !!! TODO: fix jax.scipy numerical failures with properly fetched lapack calls - # As of now, we raise a warning prompting the user to use a callback with catalyst.accelerate() - # https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax - # https://github.com/PennyLaneAI/catalyst/issues/753 - # https://github.com/PennyLaneAI/catalyst/issues/1071 - (jax.scipy.linalg, "expm", JaxLinalgWarner(jax.scipy.linalg.expm)), - (jax.scipy.linalg, "lu", JaxLinalgWarner(jax.scipy.linalg.lu)), - (jax.scipy.linalg, "lu_factor", JaxLinalgWarner(jax.scipy.linalg.lu_factor)), - (jax.scipy.linalg, "lu_solve", JaxLinalgWarner(jax.scipy.linalg.lu_solve)), ): # TODO: improve PyTree handling diff --git a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp new file mode 100644 index 0000000000..b150eb34e2 --- /dev/null +++ b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp @@ -0,0 +1,941 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file has been modified from its original form in the JAX project at +// https://github.com/google/jax released under the Apache License, Version 2.0, +// with the following copyright notice: + +// Copyright 2021 The JAX Authors. + +/* + * This file is a modified version of + * + * https://github.com/google/jax/blob/jaxlib-v0.4.28/jaxlib/cpu/lapack_kernels.cc + * + * from jaxlib-v0.4.28. + * + * See note in lapack_kernels.h for a high-level explanation of the + * modifications and the motivation for them. Specifically, the modifications + * made in this file include: + * + * 1. Used the C interfaces to the BLAS and LAPACK routines instead of the + * FORTRAN interfaces, and always use row-major matrix layout. This + * modification generally involves the following: + * - Adding the matrix layout parameter as the first argument to the BLAS/ + * LAPACK call, either `CblasRowMajor` for BLAS or `LAPACK_ROW_MAJOR` + * for LAPACK. + * - Specifying the array leading dimensions (e.g. `lda`) such that they + * are dependent upon the matrix layout, rather than hard-coding them. + * Note that these should always evaluate to the value required for + * row-major matrix layout (typically the number of columns n of the + * matrix). + * - Remove parameters used by the FORTRAN interfaces but not by the C + * interfaces, e.g. the workspace array parameters `lwork`, `rwork`, + * `iwork`, etc. + * 2. Guarded the #include of the ABSEIL `dynamic_annotations.h header by the + * `USE_ABSEIL_LIB` macro and the uses of `ABSL_ANNOTATE_MEMORY_IS_INITIALIZED`, + * since they are not needed for Catalyst. + * 3. Opportunistically improved const-correctness. + * 4. Applied Catalyst C++ code formatting. 
+ */ + +#include "lapack_kernels.hpp" + +#include +#include +#include +#include + +#ifdef USE_ABSEIL_LIB +#include "absl/base/dynamic_annotations.h" +#endif + +namespace { + +inline int64_t catch_lapack_int_overflow(const std::string &source, int64_t value) +{ + if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) { + return value; + } + else { + if (value > std::numeric_limits::max()) { + throw std::overflow_error(source + "(=" + std::to_string(value) + + ") exceeds maximum value of jax::lapack_int"); + } + return value; + } +} + +} // namespace + +namespace jax { + +static_assert(sizeof(lapack_int) == sizeof(int32_t), "Expected LAPACK integers to be 32-bit"); + +// Trsm +// ~~~~ + +template typename RealTrsm::FnType *RealTrsm::fn = nullptr; + +template void RealTrsm::Kernel(void *out, void **data, XlaCustomCallStatus *) +{ + const int32_t left_side = *reinterpret_cast(data[0]); + const int32_t lower = *reinterpret_cast(data[1]); + const int32_t trans_a = *reinterpret_cast(data[2]); + const int32_t diag = *reinterpret_cast(data[3]); + const int m = *reinterpret_cast(data[4]); + const int n = *reinterpret_cast(data[5]); + const int batch = *reinterpret_cast(data[6]); + const T alpha = *reinterpret_cast(data[7]); + const T *a = reinterpret_cast(data[8]); + T *b = reinterpret_cast(data[9]); + + T *x = reinterpret_cast(out); + if (x != b) { + std::memcpy(x, b, + static_cast(batch) * static_cast(m) * + static_cast(n) * sizeof(T)); + } + + constexpr CBLAS_ORDER corder = CblasRowMajor; + const CBLAS_SIDE cside = left_side ? CblasLeft : CblasRight; + const CBLAS_UPLO cuplo = lower ? CblasLower : CblasUpper; + const CBLAS_TRANSPOSE ctransa = (trans_a == 1) ? CblasTrans + : (trans_a == 2) ? CblasConjTrans + : CblasNoTrans; + const CBLAS_DIAG cdiag = diag ? CblasUnit : CblasNonUnit; + const int lda = left_side ? m : n; + const int ldb = (corder == CblasColMajor) ? m : n; // Note: m if col-major, n if row-major + + const int64_t x_plus = static_cast(m) * static_cast(n); + const int64_t a_plus = static_cast(lda) * static_cast(lda); + + for (int i = 0; i < batch; ++i) { + fn(CblasRowMajor, cside, cuplo, ctransa, cdiag, m, n, alpha, a, lda, x, ldb); + x += x_plus; + a += a_plus; + } +} + +template typename ComplexTrsm::FnType *ComplexTrsm::fn = nullptr; + +template void ComplexTrsm::Kernel(void *out, void **data, XlaCustomCallStatus *) +{ + const int32_t left_side = *reinterpret_cast(data[0]); + const int32_t lower = *reinterpret_cast(data[1]); + const int32_t trans_a = *reinterpret_cast(data[2]); + const int32_t diag = *reinterpret_cast(data[3]); + const int m = *reinterpret_cast(data[4]); + const int n = *reinterpret_cast(data[5]); + const int batch = *reinterpret_cast(data[6]); + const T *alpha = reinterpret_cast(data[7]); + const T *a = reinterpret_cast(data[8]); + T *b = reinterpret_cast(data[9]); + + T *x = reinterpret_cast(out); + if (x != b) { + std::memcpy(x, b, + static_cast(batch) * static_cast(m) * + static_cast(n) * sizeof(T)); + } + + constexpr CBLAS_ORDER corder = CblasRowMajor; + const CBLAS_SIDE cside = left_side ? CblasLeft : CblasRight; + const CBLAS_UPLO cuplo = lower ? CblasLower : CblasUpper; + const CBLAS_TRANSPOSE ctransa = (trans_a == 1) ? CblasTrans + : (trans_a == 2) ? CblasConjTrans + : CblasNoTrans; + const CBLAS_DIAG cdiag = diag ? CblasUnit : CblasNonUnit; + const int lda = left_side ? m : n; + const int ldb = (corder == CblasColMajor) ? 
m : n; // Note: m if col-major, n if row-major + + const int64_t x_plus = static_cast(m) * static_cast(n); + const int64_t a_plus = static_cast(lda) * static_cast(lda); + + for (int i = 0; i < batch; ++i) { + fn(CblasRowMajor, cside, cuplo, ctransa, cdiag, m, n, alpha, a, lda, x, ldb); + x += x_plus; + a += a_plus; + } +} + +template struct RealTrsm; +template struct RealTrsm; +template struct ComplexTrsm>; +template struct ComplexTrsm>; + +// Getrf +// ~~~~~ + +template typename Getrf::FnType *Getrf::fn = nullptr; + +template void Getrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int m = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const T *a_in = reinterpret_cast(data[3]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + int *ipiv = reinterpret_cast(out[1]); + int *info = reinterpret_cast(out[2]); + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(m) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m; + + for (int i = 0; i < b; ++i) { + *info = fn(corder, m, n, a_out, lda, ipiv); + a_out += static_cast(m) * static_cast(n); + ipiv += std::min(m, n); + ++info; + } +} + +template struct Getrf; +template struct Getrf; +template struct Getrf>; +template struct Getrf>; + +// Geqrf +// ~~~~~ + +template typename Geqrf::FnType *Geqrf::fn = nullptr; + +template void Geqrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int m = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const int lwork = *(reinterpret_cast(data[3])); + const T *a_in = reinterpret_cast(data[4]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + T *tau = reinterpret_cast(out[1]); + int *info = reinterpret_cast(out[2]); + T *work = reinterpret_cast(out[3]); + + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(m) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m; + + for (int i = 0; i < b; ++i) { + *info = fn(LAPACK_ROW_MAJOR, m, n, a_out, lda, tau); + a_out += static_cast(m) * static_cast(n); + tau += std::min(m, n); + ++info; + } +} + +template struct Geqrf; +template struct Geqrf; +template struct Geqrf>; +template struct Geqrf>; + +// Orgqr +// ~~~~~ + +template typename Orgqr::FnType *Orgqr::fn = nullptr; + +template void Orgqr::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int m = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const int k = *(reinterpret_cast(data[3])); + const int lwork = *(reinterpret_cast(data[4])); + const T *a_in = reinterpret_cast(data[5]); + T *tau = reinterpret_cast(data[6]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + int *info = reinterpret_cast(out[1]); + T *work = reinterpret_cast(out[2]); + + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(m) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const int lda = (corder == LAPACK_ROW_MAJOR) ? 
n : m; + + for (int i = 0; i < b; ++i) { + *info = fn(LAPACK_ROW_MAJOR, m, n, k, a_out, lda, tau); + a_out += static_cast(m) * static_cast(n); + tau += k; + ++info; + } +} + +template struct Orgqr; +template struct Orgqr; +template struct Orgqr>; +template struct Orgqr>; + +// Potrf +// ~~~~~ + +template typename Potrf::FnType *Potrf::fn = nullptr; + +template void Potrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t lower = *(reinterpret_cast(data[0])); + const int b = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const T *a_in = reinterpret_cast(data[3]); + const char uplo = lower ? 'L' : 'U'; + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + int *info = reinterpret_cast(out[1]); + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(n) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + + for (int i = 0; i < b; ++i) { + *info = fn(corder, uplo, n, a_out, n); + a_out += static_cast(n) * static_cast(n); + ++info; + } +} + +template struct Potrf; +template struct Potrf; +template struct Potrf>; +template struct Potrf>; + +// Gesdd + +static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) +{ + if (!job_opt_compute_uv) { + return 'N'; + } + else if (!job_opt_full_matrices) { + return 'S'; + } + return 'A'; +} + +static int Gesdd_ldu(const int order, const char jobz, const int m, const int n) +{ + int ldu = 0; + if (jobz == 'N') { + ldu = 1; + } + else if (jobz == 'A') { + ldu = m; + } + else if (jobz == 'S') { + if (m >= n) { + ldu = (order == LAPACK_ROW_MAJOR) ? n : m; + } + else { + ldu = m; + } + } + return ldu; +} + +static int Gesdd_ldvt(const int order, const char jobz, const int m, const int n) +{ + int ldu = 0; + if (jobz == 'N') { + ldu = 1; + } + else if (jobz == 'A') { + ldu = n; + } + else if (jobz == 'S') { + if (m >= n) { + ldu = n; + } + else { + ldu = (order == LAPACK_ROW_MAJOR) ? n : m; + } + } + return ldu; +} + +template typename RealGesdd::FnType *RealGesdd::fn = nullptr; + +template void RealGesdd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); + const int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); + const int b = *(reinterpret_cast(data[2])); + const int m = *(reinterpret_cast(data[3])); + const int n = *(reinterpret_cast(data[4])); + const int lwork = *(reinterpret_cast(data[5])); + T *a_in = reinterpret_cast(data[6]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + T *s = reinterpret_cast(out[1]); + T *u = reinterpret_cast(out[2]); + T *vt = reinterpret_cast(out[3]); + int *info = reinterpret_cast(out[4]); + int *iwork = reinterpret_cast(out[5]); + T *work = reinterpret_cast(out[6]); + + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(m) * static_cast(n) * + sizeof(T)); + } + + const char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); + + constexpr int corder = LAPACK_ROW_MAJOR; + const int lda = (corder == LAPACK_ROW_MAJOR) ? 
n : m; + const int ldu = Gesdd_ldu(corder, jobz, m, n); + const int tdu = ldu; + const int ldvt = Gesdd_ldvt(corder, jobz, m, n); + + for (int i = 0; i < b; ++i) { + *info = fn(corder, jobz, m, n, a_out, lda, s, u, ldu, vt, ldvt); + a_out += static_cast(m) * n; + s += std::min(m, n); + u += static_cast(m) * tdu; + vt += static_cast(ldvt) * n; + ++info; + } +} + +template typename ComplexGesdd::FnType *ComplexGesdd::fn = nullptr; + +template +void ComplexGesdd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); + const int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); + const int b = *(reinterpret_cast(data[2])); + const int m = *(reinterpret_cast(data[3])); + const int n = *(reinterpret_cast(data[4])); + const int lwork = *(reinterpret_cast(data[5])); + T *a_in = reinterpret_cast(data[6]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + typename T::value_type *s = reinterpret_cast(out[1]); + T *u = reinterpret_cast(out[2]); + T *vt = reinterpret_cast(out[3]); + int *info = reinterpret_cast(out[4]); + int *iwork = reinterpret_cast(out[5]); + typename T::value_type *rwork = reinterpret_cast(out[6]); + T *work = reinterpret_cast(out[7]); + + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(m) * static_cast(n) * + sizeof(T)); + } + + const char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); + + constexpr int corder = LAPACK_ROW_MAJOR; + const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m; + const int ldu = Gesdd_ldu(corder, jobz, m, n); + const int tdu = ldu; + const int ldvt = Gesdd_ldvt(corder, jobz, m, n); + + for (int i = 0; i < b; ++i) { + *info = fn(LAPACK_ROW_MAJOR, jobz, m, n, a_out, lda, s, u, ldu, vt, ldvt); + a_out += static_cast(m) * n; + s += std::min(m, n); + u += static_cast(m) * tdu; + vt += static_cast(ldvt) * n; + ++info; + } +} + +template struct RealGesdd; +template struct RealGesdd; +template struct ComplexGesdd>; +template struct ComplexGesdd>; + +// Syevd/Heevd +// ~~~~~~~~~~~ + +template typename RealSyevd::FnType *RealSyevd::fn = nullptr; + +template void RealSyevd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t lower = *(reinterpret_cast(data[0])); + const int b = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const T *a_in = reinterpret_cast(data[3]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + T *w_out = reinterpret_cast(out[1]); + int *info = reinterpret_cast(out[2]); + T *work = reinterpret_cast(out[3]); + int *iwork = reinterpret_cast(out[4]); + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(n) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const char jobz = 'V'; + const char uplo = lower ? 
'L' : 'U'; + + for (int i = 0; i < b; ++i) { + *info = fn(corder, jobz, uplo, n, a_out, n, w_out); + a_out += static_cast(n) * n; + w_out += n; + ++info; + } +} + +template typename ComplexHeevd::FnType *ComplexHeevd::fn = nullptr; + +template +void ComplexHeevd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t lower = *(reinterpret_cast(data[0])); + const int b = *(reinterpret_cast(data[1])); + const int n = *(reinterpret_cast(data[2])); + const T *a_in = reinterpret_cast(data[3]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + typename T::value_type *w_out = reinterpret_cast(out[1]); + int *info = reinterpret_cast(out[2]); + T *work = reinterpret_cast(out[3]); + typename T::value_type *rwork = reinterpret_cast(out[4]); + int *iwork = reinterpret_cast(out[5]); + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(n) * static_cast(n) * + sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const char jobz = 'V'; + const char uplo = lower ? 'L' : 'U'; + + for (int i = 0; i < b; ++i) { + *info = fn(corder, jobz, uplo, n, a_out, n, w_out); + a_out += static_cast(n) * n; + w_out += n; + ++info; + } +} + +template struct RealSyevd; +template struct RealSyevd; +template struct ComplexHeevd>; +template struct ComplexHeevd>; + +// LAPACK uses a packed representation to represent a mixture of real +// eigenvectors and complex conjugate pairs. This helper unpacks the +// representation into regular complex matrices. +template +static void UnpackEigenvectors(int n, const T *im_eigenvalues, const T *packed, + std::complex *unpacked) +{ + T re, im; + int j; + j = 0; + while (j < n) { + if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) { + for (int k = 0; k < n; ++k) { + unpacked[j * n + k] = {packed[j * n + k], 0.}; + } + ++j; + } + else { + for (int k = 0; k < n; ++k) { + re = packed[j * n + k]; + im = packed[(j + 1) * n + k]; + unpacked[j * n + k] = {re, im}; + unpacked[(j + 1) * n + k] = {re, -im}; + } + j += 2; + } + } +} + +// Geev +// ~~~~ + +template typename RealGeev::FnType *RealGeev::fn = nullptr; + +template void RealGeev::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int n_int = *(reinterpret_cast(data[1])); + const int64_t n = n_int; + const char jobvl = *(reinterpret_cast(data[2])); + const char jobvr = *(reinterpret_cast(data[3])); + + const T *a_in = reinterpret_cast(data[4]); + + void **out = reinterpret_cast(out_tuple); + T *a_work = reinterpret_cast(out[0]); + T *vl_work = reinterpret_cast(out[1]); + T *vr_work = reinterpret_cast(out[2]); + + T *wr_out = reinterpret_cast(out[3]); + T *wi_out = reinterpret_cast(out[4]); + std::complex *vl_out = reinterpret_cast *>(out[5]); + std::complex *vr_out = reinterpret_cast *>(out[6]); + int *info = reinterpret_cast(out[7]); + + constexpr int corder = LAPACK_ROW_MAJOR; + + // TODO(phawkins): preallocate workspace using XLA. 
+ *info = fn(corder, jobvl, jobvr, n_int, a_work, n_int, wr_out, wi_out, vl_work, n_int, vr_work, + n_int); + + auto is_finite = [](T *a_work, int64_t n) { + for (int64_t j = 0; j < n; ++j) { + for (int64_t k = 0; k < n; ++k) { + if (!std::isfinite(a_work[j * n + k])) { + return false; + } + } + } + return true; + }; + for (int i = 0; i < b; ++i) { + size_t a_size = n * n * sizeof(T); + std::memcpy(a_work, a_in, a_size); + if (is_finite(a_work, n)) { + *info = fn(corder, jobvl, jobvr, n_int, a_work, n_int, wr_out, wi_out, vl_work, n_int, + vr_work, n_int); +#ifdef USE_ABSEIL_LIB + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_work, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_work, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info, sizeof(int)); +#endif + if (info[0] == 0) { + UnpackEigenvectors(n, wi_out, vl_work, vl_out); + UnpackEigenvectors(n, wi_out, vr_work, vr_out); + } + } + else { + *info = -4; + } + a_in += n * n; + wr_out += n; + wi_out += n; + vl_out += n * n; + vr_out += n * n; + ++info; + } +} + +template typename ComplexGeev::FnType *ComplexGeev::fn = nullptr; + +template +void ComplexGeev::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int n_int = *(reinterpret_cast(data[1])); + const int64_t n = n_int; + const char jobvl = *(reinterpret_cast(data[2])); + const char jobvr = *(reinterpret_cast(data[3])); + + const T *a_in = reinterpret_cast(data[4]); + + void **out = reinterpret_cast(out_tuple); + T *a_work = reinterpret_cast(out[0]); + typename T::value_type *r_work = reinterpret_cast(out[1]); + + T *w_out = reinterpret_cast(out[2]); + T *vl_out = reinterpret_cast(out[3]); + T *vr_out = reinterpret_cast(out[4]); + int *info = reinterpret_cast(out[5]); + + constexpr int corder = LAPACK_ROW_MAJOR; + + // TODO(phawkins): preallocate workspace using XLA. 
+ *info = fn(corder, jobvl, jobvr, n_int, a_work, n_int, w_out, vl_out, n_int, vr_out, n_int); + + auto is_finite = [](T *a_work, int64_t n) { + for (int64_t j = 0; j < n; ++j) { + for (int64_t k = 0; k < n; ++k) { + T v = a_work[j * n + k]; + if (!std::isfinite(v.real()) || !std::isfinite(v.imag())) { + return false; + } + } + } + return true; + }; + + for (int i = 0; i < b; ++i) { + size_t a_size = n * n * sizeof(T); + std::memcpy(a_work, a_in, a_size); + if (is_finite(a_work, n)) { + *info = + fn(corder, jobvl, jobvr, n_int, a_work, n_int, w_out, vl_out, n_int, vr_out, n_int); +#ifdef USE_ABSEIL_LIB + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_work, a_size); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vl_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vr_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); +#endif + } + else { + *info = -4; + } + a_in += n * n; + w_out += n; + vl_out += n * n; + vr_out += n * n; + info += 1; + } +} + +template struct RealGeev; +template struct RealGeev; +template struct ComplexGeev>; +template struct ComplexGeev>; + +// Gees +// ~~~~ + +template typename RealGees::FnType *RealGees::fn = nullptr; + +template void RealGees::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int n_int = *(reinterpret_cast(data[1])); + const int64_t n = n_int; + const char jobvs = *(reinterpret_cast(data[2])); + const char sort = *(reinterpret_cast(data[3])); + + const T *a_in = reinterpret_cast(data[4]); + + // bool* select (T, T) = reinterpret_cast(data[5]); + bool (*select)(T, T) = nullptr; + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + + T *wr_out = reinterpret_cast(out[1]); + T *wi_out = reinterpret_cast(out[2]); + T *vs_out = reinterpret_cast(out[3]); + int *sdim_out = reinterpret_cast(out[4]); + int *info = reinterpret_cast(out[5]); + + constexpr int corder = LAPACK_ROW_MAJOR; + + *info = fn(corder, jobvs, sort, select, n_int, a_out, n_int, sdim_out, wr_out, wi_out, vs_out, + n_int); + + size_t a_size = static_cast(n) * static_cast(n) * sizeof(T); + if (a_out != a_in) { + std::memcpy(a_out, a_in, static_cast(b) * a_size); + } + + for (int i = 0; i < b; ++i) { + *info = fn(corder, jobvs, sort, select, n_int, a_out, n_int, sdim_out, wr_out, wi_out, + vs_out, n_int); +#ifdef USE_ABSEIL_LIB + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(a_out, a_size); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wr_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(wi_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); +#endif + + a_in += n * n; + a_out += n * n; + wr_out += n; + wi_out += n; + vs_out += n * n; + ++sdim_out; + ++info; + } +} + +template typename ComplexGees::FnType *ComplexGees::fn = nullptr; + +template +void ComplexGees::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int b = *(reinterpret_cast(data[0])); + const int n_int = *(reinterpret_cast(data[1])); + const int64_t n = n_int; + const char jobvs = *(reinterpret_cast(data[2])); + const char sort = *(reinterpret_cast(data[3])); + + const T *a_in = reinterpret_cast(data[4]); + + // bool* select (T, T) = reinterpret_cast(data[5]); + bool (*select)(T) = nullptr; + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + 
typename T::value_type *r_work = reinterpret_cast(out[1]); + T *w_out = reinterpret_cast(out[2]); + T *vs_out = reinterpret_cast(out[3]); + int *sdim_out = reinterpret_cast(out[4]); + int *info = reinterpret_cast(out[5]); + + constexpr int corder = LAPACK_ROW_MAJOR; + + *info = fn(corder, jobvs, sort, select, n_int, a_out, n_int, sdim_out, w_out, vs_out, n_int); + + if (a_out != a_in) { + std::memcpy(a_out, a_in, + static_cast(b) * static_cast(n) * static_cast(n) * + sizeof(T)); + } + + for (int i = 0; i < b; ++i) { + *info = + fn(corder, jobvs, sort, select, n_int, a_out, n_int, sdim_out, w_out, vs_out, n_int); +#ifdef USE_ABSEIL_LIB + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(w_out, sizeof(T) * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(vs_out, sizeof(T) * n * n); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_out, sizeof(int)); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(sdim_out, sizeof(int)); +#endif + + a_in += n * n; + a_out += n * n; + w_out += n; + vs_out += n * n; + ++info; + ++sdim_out; + } +} + +template struct RealGees; +template struct RealGees; +template struct ComplexGees>; +template struct ComplexGees>; + +// Gehrd + +template typename Gehrd::FnType *Gehrd::fn = nullptr; + +template void Gehrd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t n = *reinterpret_cast(data[0]); + const int32_t ilo = *reinterpret_cast(data[1]); + const int32_t ihi = *reinterpret_cast(data[2]); + const int32_t lda = *reinterpret_cast(data[3]); + const int32_t batch = *reinterpret_cast(data[4]); + const int32_t lwork = *reinterpret_cast(data[5]); + T *a = reinterpret_cast(data[6]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + T *tau = reinterpret_cast(out[1]); + int *info = reinterpret_cast(out[2]); + T *work = reinterpret_cast(out[3]); + + if (a_out != a) { + std::memcpy(a_out, a, + static_cast(batch) * static_cast(n) * + static_cast(n) * sizeof(T)); + } + + const int64_t a_plus = static_cast(lda) * static_cast(n); + + constexpr int corder = LAPACK_ROW_MAJOR; + + for (int i = 0; i < batch; ++i) { + *info = fn(corder, n, ilo, ihi, a_out, lda, tau); + a_out += a_plus; + tau += n - 1; + ++info; + } +} + +template struct Gehrd; +template struct Gehrd; +template struct Gehrd>; +template struct Gehrd>; + +// Sytrd +// ~~~~~ + +template typename Sytrd::FnType *Sytrd::fn = nullptr; + +template void Sytrd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *) +{ + const int32_t n = *reinterpret_cast(data[0]); + const int32_t lower = *reinterpret_cast(data[1]); + const int32_t lda = *reinterpret_cast(data[2]); + const int32_t batch = *reinterpret_cast(data[3]); + const int32_t lwork = *reinterpret_cast(data[4]); + T *a = reinterpret_cast(data[5]); + + void **out = reinterpret_cast(out_tuple); + T *a_out = reinterpret_cast(out[0]); + typedef typename real_type::type Real; + Real *d = reinterpret_cast(out[1]); + Real *e = reinterpret_cast(out[2]); + T *tau = reinterpret_cast(out[3]); + int *info = reinterpret_cast(out[4]); + T *work = reinterpret_cast(out[5]); + + if (a_out != a) { + std::memcpy(a_out, a, + static_cast(batch) * static_cast(n) * + static_cast(n) * sizeof(T)); + } + + constexpr int corder = LAPACK_ROW_MAJOR; + const char cuplo = lower ? 
'L' : 'U'; + + const int64_t a_plus = static_cast(lda) * static_cast(n); + + for (int i = 0; i < batch; ++i) { + *info = fn(corder, cuplo, n, a_out, lda, d, e, tau); + a_out += a_plus; + d += n; + e += n - 1; + tau += n - 1; + ++info; + } +} + +template struct Sytrd; +template struct Sytrd; +template struct Sytrd>; +template struct Sytrd>; + +} // namespace jax diff --git a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.hpp b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.hpp new file mode 100644 index 0000000000..9ec0aa3cfc --- /dev/null +++ b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.hpp @@ -0,0 +1,223 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file has been modified from its original form in the JAX project at +// https://github.com/google/jax released under the Apache License, Version 2.0, +// with the following copyright notice: + +// Copyright 2021 The JAX Authors. + +/* + * This file is a modified version of + * + * https://github.com/google/jax/blob/jaxlib-v0.4.28/jaxlib/cpu/lapack_kernels.h + * + * from jaxlib-v0.4.28. + * + * The LAPACK kernels below have been modified from their original form in JAX + * to use the respective C interfaces to the underlying BLAS and LAPACK + * routines, rather than the FORTRAN interfaces that JAX uses, for compatibility + * with Catalyst. Recall that the FORTRAN interfaces require arrays and matrices + * in column-major order, while the C interfaces allow row-major order, which is + * required for Catalyst. + * + * In addition, the following modifications have been made: + * + * 1. Guarded the #include of the XLA `custom_call_status.h` header by the + * `USE_XLA_LIB` macro; simply declared the `XlaCustomCallStatus` type + * instead, since it is not explicitly used. + * 2. Copied the BLAS and LAPACK enums and option codes (e.g. `CBLAS_ORDER` + * and `LAPACK_ROW_MAJOR`) needed for the C interfaces. + * 3. Applied Catalyst C++ code formatting. + */ + +#ifndef JAXLIB_CPU_LAPACK_KERNELS_H_ +#define JAXLIB_CPU_LAPACK_KERNELS_H_ + +#include +#include + +#ifdef USE_XLA_LIB +#include "xla/service/custom_call_status.h" +#else +typedef struct XlaCustomCallStatus_ XlaCustomCallStatus; +#endif + +// Underlying function pointers (e.g., Trsm::Fn) are initialized either +// by the pybind wrapper that links them to an existing SciPy lapack instance, +// or using the lapack_kernels_strong.cc static initialization to link them +// directly to lapack for use in a pure C++ context. 
+
+namespace jax {
+
+// Copied from cblas.h
+typedef enum CBLAS_ORDER { CblasRowMajor = 101, CblasColMajor = 102 } CBLAS_ORDER;
+typedef enum CBLAS_TRANSPOSE {
+    CblasNoTrans = 111,
+    CblasTrans = 112,
+    CblasConjTrans = 113,
+    CblasConjNoTrans = 114
+} CBLAS_TRANSPOSE;
+typedef enum CBLAS_UPLO { CblasUpper = 121, CblasLower = 122 } CBLAS_UPLO;
+typedef enum CBLAS_DIAG { CblasNonUnit = 131, CblasUnit = 132 } CBLAS_DIAG;
+typedef enum CBLAS_SIDE { CblasLeft = 141, CblasRight = 142 } CBLAS_SIDE;
+typedef CBLAS_ORDER CBLAS_LAYOUT;
+
+typedef int lapack_int;
+
+// Copied from lapacke.h
+#define LAPACK_ROW_MAJOR 101
+#define LAPACK_COL_MAJOR 102
+
+// trsm: Solves a triangular matrix equation.
+template <typename T> struct RealTrsm {
+    using FnType = void(const CBLAS_LAYOUT layout, const CBLAS_SIDE Side, const CBLAS_UPLO Uplo,
+                        const CBLAS_TRANSPOSE TransA, const CBLAS_DIAG Diag, const int M,
+                        const int N, const T alpha, const T *A, const int lda, T *B,
+                        const int ldb);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+template <typename T> struct ComplexTrsm {
+    using FnType = void(const CBLAS_LAYOUT layout, const CBLAS_SIDE Side, const CBLAS_UPLO Uplo,
+                        const CBLAS_TRANSPOSE TransA, const CBLAS_DIAG Diag, const int M,
+                        const int N, const void *alpha, const void *A, const int lda, void *B,
+                        const int ldb);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// getrf: Computes the LU factorization of a general m-by-n matrix
+template <typename T> struct Getrf {
+    using FnType = lapack_int(int matrix_layout, lapack_int m, lapack_int n, T *a, lapack_int lda,
+                              lapack_int *ipiv);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// geqrf: Computes the QR factorization of a general m-by-n matrix.
+template <typename T> struct Geqrf {
+    using FnType = lapack_int(int matrix_layout, lapack_int m, lapack_int n, T *a, lapack_int lda,
+                              T *tau);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// orgqr: Generates the real orthogonal matrix Q of the QR factorization formed by geqrf
+template <typename T> struct Orgqr {
+    using FnType = lapack_int(int matrix_layout, lapack_int m, lapack_int n, lapack_int k, T *a,
+                              lapack_int lda, const T *tau);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// potrf: Computes the Cholesky factorization of a symmetric (Hermitian) positive-definite matrix
+template <typename T> struct Potrf {
+    using FnType = lapack_int(int matrix_layout, char uplo, lapack_int n, T *a, lapack_int lda);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// gesdd: computes the singular value decomposition (SVD) of an m-by-n matrix
+template <typename T> struct RealGesdd {
+    using FnType = lapack_int(int matrix_layout, char jobz, lapack_int m, lapack_int n, T *a,
+                              lapack_int lda, T *s, T *u, lapack_int ldu, T *vt, lapack_int ldvt);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+template <typename T> struct ComplexGesdd {
+    using FnType = lapack_int(int matrix_layout, char jobz, lapack_int m, lapack_int n, T *a,
+                              lapack_int lda, typename T::value_type *s, T *u, lapack_int ldu,
+                              T *vt, lapack_int ldvt);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// syevd: Computes all eigenvalues and, optionally, all eigenvectors of a real symmetric matrix
+template <typename T> struct RealSyevd {
+    using FnType = lapack_int(int matrix_layout, char jobz, char uplo, lapack_int n, T *a,
+                              lapack_int lda, T *w);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// heevd: Computes all eigenvalues and, optionally, all eigenvectors of a complex Hermitian matrix
+template <typename T> struct ComplexHeevd {
+    using FnType = lapack_int(int matrix_layout, char jobz, char uplo, lapack_int n, T *a,
+                              lapack_int lda, typename T::value_type *w);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// geev: Computes the eigenvalues and left and right eigenvectors of a general matrix
+template <typename T> struct RealGeev {
+    using FnType = lapack_int(int matrix_layout, char jobvl, char jobvr, lapack_int n, T *a,
+                              lapack_int lda, T *wr, T *wi, T *vl, lapack_int ldvl, T *vr,
+                              lapack_int ldvr);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+template <typename T> struct ComplexGeev {
+    using FnType = lapack_int(int matrix_layout, char jobvl, char jobvr, lapack_int n, T *a,
+                              lapack_int lda, T *w, T *vl, lapack_int ldvl, T *vr,
+                              lapack_int ldvr);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// gees: Computes the eigenvalues and Schur factorization of a general matrix
+template <typename T> struct RealGees {
+    using FnType = lapack_int(int matrix_layout, char jobvs, char sort, bool (*select)(T, T),
+                              lapack_int n, T *a, lapack_int lda, lapack_int *sdim, T *wr, T *wi,
+                              T *vs, lapack_int ldvs);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+template <typename T> struct ComplexGees {
+    using FnType = lapack_int(int matrix_layout, char jobvs, char sort, bool (*select)(T),
+                              lapack_int n, T *a, lapack_int lda, lapack_int *sdim, T *w, T *vs,
+                              lapack_int ldvs);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+// Gehrd: Reduces a non-symmetric square matrix to upper Hessenberg form
+template <typename T> struct Gehrd {
+    using FnType = lapack_int(int matrix_layout, lapack_int n, lapack_int ilo, lapack_int ihi,
+                              T *a, lapack_int lda, T *tau);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+template <typename T> struct real_type {
+    typedef T type;
+};
+template <typename T> struct real_type<std::complex<T>> {
+    typedef T type;
+};
+
+// Sytrd/Hetrd: Reduces a symmetric (Hermitian) square matrix to tridiagonal form
+template <typename T> struct Sytrd {
+    using FnType = lapack_int(int matrix_layout, char uplo, lapack_int n, T *a, lapack_int lda,
+                              typename real_type<T>::type *d, typename real_type<T>::type *e,
+                              T *tau);
+    static FnType *fn;
+    static void Kernel(void *out, void **data, XlaCustomCallStatus *);
+};
+
+} // namespace jax
+
+#endif // JAXLIB_CPU_LAPACK_KERNELS_H_
diff --git a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp
new file mode 100644
index 0000000000..b3875101be
--- /dev/null
+++ b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp
@@ -0,0 +1,162 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file has been modified from its original form in the JAX project at
+// https://github.com/google/jax released under the Apache License, Version 2.0,
+// with the following copyright notice:
+
+// Copyright 2021 The JAX Authors.
+
+/*
+ * This file is a modified version of
+ *
+ * https://github.com/google/jax/blob/jaxlib-v0.4.28/jaxlib/cpu/lapack_kernels_using_lapack.cc
+ *
+ * from jaxlib-v0.4.28.
+ *
+ * See note in lapack_kernels.h for a high-level explanation of the
+ * modifications and the motivation for them. Specifically, the names of the
+ * BLAS and LAPACK routine symbols have been changed from the FORTRAN interfaces
+ * to the equivalent C interfaces. For example, the `dtrsm` BLAS routine has
+ * been changed from `dtrsm_` to `cblas_dtrsm`, and the `dgetrf` LAPACK routine
+ * has been changed from `dgetrf_` to `LAPACKE_dgetrf`.
+ */
+
+#include "lapack_kernels.hpp"
+
+// From a Python binary, JAX obtains its LAPACK/BLAS kernels from Scipy, but
+// a C++ user should link against LAPACK directly. This is needed when using
+// JAX-generated HLO from C++.
+ +extern "C" { + +jax::RealTrsm::FnType cblas_strsm; +jax::RealTrsm::FnType cblas_dtrsm; +jax::ComplexTrsm>::FnType cblas_ctrsm; +jax::ComplexTrsm>::FnType cblas_ztrsm; + +jax::Getrf::FnType LAPACKE_sgetrf; +jax::Getrf::FnType LAPACKE_dgetrf; +jax::Getrf>::FnType LAPACKE_cgetrf; +jax::Getrf>::FnType LAPACKE_zgetrf; + +jax::Geqrf::FnType LAPACKE_sgeqrf; +jax::Geqrf::FnType LAPACKE_dgeqrf; +jax::Geqrf>::FnType LAPACKE_cgeqrf; +jax::Geqrf>::FnType LAPACKE_zgeqrf; + +jax::Orgqr::FnType LAPACKE_sorgqr; +jax::Orgqr::FnType LAPACKE_dorgqr; +jax::Orgqr>::FnType LAPACKE_cungqr; +jax::Orgqr>::FnType LAPACKE_zungqr; + +jax::Potrf::FnType LAPACKE_spotrf; +jax::Potrf::FnType LAPACKE_dpotrf; +jax::Potrf>::FnType LAPACKE_cpotrf; +jax::Potrf>::FnType LAPACKE_zpotrf; + +jax::RealGesdd::FnType LAPACKE_sgesdd; +jax::RealGesdd::FnType LAPACKE_dgesdd; +jax::ComplexGesdd>::FnType LAPACKE_cgesdd; +jax::ComplexGesdd>::FnType LAPACKE_zgesdd; + +jax::RealSyevd::FnType LAPACKE_ssyevd; +jax::RealSyevd::FnType LAPACKE_dsyevd; +jax::ComplexHeevd>::FnType LAPACKE_cheevd; +jax::ComplexHeevd>::FnType LAPACKE_zheevd; + +jax::RealGeev::FnType LAPACKE_sgeev; +jax::RealGeev::FnType LAPACKE_dgeev; +jax::ComplexGeev>::FnType LAPACKE_cgeev; +jax::ComplexGeev>::FnType LAPACKE_zgeev; + +jax::RealGees::FnType LAPACKE_sgees; +jax::RealGees::FnType LAPACKE_dgees; +jax::ComplexGees>::FnType LAPACKE_cgees; +jax::ComplexGees>::FnType LAPACKE_zgees; + +jax::Gehrd::FnType LAPACKE_sgehrd; +jax::Gehrd::FnType LAPACKE_dgehrd; +jax::Gehrd>::FnType LAPACKE_cgehrd; +jax::Gehrd>::FnType LAPACKE_zgehrd; + +jax::Sytrd::FnType LAPACKE_ssytrd; +jax::Sytrd::FnType LAPACKE_dsytrd; +jax::Sytrd>::FnType LAPACKE_chetrd; +jax::Sytrd>::FnType LAPACKE_zhetrd; + +} // extern "C" + +namespace jax { + +static auto init = []() -> int { + RealTrsm::fn = cblas_strsm; + RealTrsm::fn = cblas_dtrsm; + ComplexTrsm>::fn = cblas_ctrsm; + ComplexTrsm>::fn = cblas_ztrsm; + + Getrf::fn = LAPACKE_sgetrf; + Getrf::fn = LAPACKE_dgetrf; + Getrf>::fn = LAPACKE_cgetrf; + Getrf>::fn = LAPACKE_zgetrf; + + Geqrf::fn = LAPACKE_sgeqrf; + Geqrf::fn = LAPACKE_dgeqrf; + Geqrf>::fn = LAPACKE_cgeqrf; + Geqrf>::fn = LAPACKE_zgeqrf; + + Orgqr::fn = LAPACKE_sorgqr; + Orgqr::fn = LAPACKE_dorgqr; + Orgqr>::fn = LAPACKE_cungqr; + Orgqr>::fn = LAPACKE_zungqr; + + Potrf::fn = LAPACKE_spotrf; + Potrf::fn = LAPACKE_dpotrf; + Potrf>::fn = LAPACKE_cpotrf; + Potrf>::fn = LAPACKE_zpotrf; + + RealGesdd::fn = LAPACKE_sgesdd; + RealGesdd::fn = LAPACKE_dgesdd; + ComplexGesdd>::fn = LAPACKE_cgesdd; + ComplexGesdd>::fn = LAPACKE_zgesdd; + + RealSyevd::fn = LAPACKE_ssyevd; + RealSyevd::fn = LAPACKE_dsyevd; + ComplexHeevd>::fn = LAPACKE_cheevd; + ComplexHeevd>::fn = LAPACKE_zheevd; + + RealGeev::fn = LAPACKE_sgeev; + RealGeev::fn = LAPACKE_dgeev; + ComplexGeev>::fn = LAPACKE_cgeev; + ComplexGeev>::fn = LAPACKE_zgeev; + + RealGees::fn = LAPACKE_sgees; + RealGees::fn = LAPACKE_dgees; + ComplexGees>::fn = LAPACKE_cgees; + ComplexGees>::fn = LAPACKE_zgees; + + Gehrd::fn = LAPACKE_sgehrd; + Gehrd::fn = LAPACKE_dgehrd; + Gehrd>::fn = LAPACKE_cgehrd; + Gehrd>::fn = LAPACKE_zgehrd; + + Sytrd::fn = LAPACKE_ssytrd; + Sytrd::fn = LAPACKE_dsytrd; + Sytrd>::fn = LAPACKE_chetrd; + Sytrd>::fn = LAPACKE_zhetrd; + + return 0; +}(); + +} // namespace jax diff --git a/frontend/catalyst/utils/libcustom_calls.cpp b/frontend/catalyst/utils/libcustom_calls.cpp index 32c72a76c9..5cdd95299e 100644 --- a/frontend/catalyst/utils/libcustom_calls.cpp +++ b/frontend/catalyst/utils/libcustom_calls.cpp @@ -12,35 +12,16 @@ // See the 
License for the specific language governing permissions and // limitations under the License. -#include #include -#include -#include -#include -#include -#include -#include -#include - -namespace { - -typedef int lapack_int; -typedef std::complex dComplex; - -static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices) -{ - if (!job_opt_compute_uv) { - return 'N'; - } - else if (!job_opt_full_matrices) { - return 'S'; - } - return 'A'; -} -} // namespace +#include "jax_cpu_lapack_kernels/lapack_kernels.hpp" -extern "C" { +#ifdef DEBUG +#include +#define DEBUG_MSG(str) std::cout << "DEBUG: " << str << std::endl; +#else +#define DEBUG_MSG(str) // No operation +#endif // MemRef type struct EncodedMemref { @@ -49,316 +30,83 @@ struct EncodedMemref { int8_t dtype; }; -void dgesdd_(char *jobz, lapack_int *m, lapack_int *n, double *a, lapack_int *lda, double *s, - double *u, lapack_int *ldu, double *vt, lapack_int *ldvt, double *work, - lapack_int *lwork, lapack_int *iwork, lapack_int *info); - -void dsyevd_(char *jobz, char *uplo, lapack_int *n, double *a, int *lda, double *w, double *work, - lapack_int *lwork, lapack_int *iwork, lapack_int *liwork, lapack_int *info); - -void dtrsm_(char *side, char *uplo, char *transa, char *diag, lapack_int *m, lapack_int *n, - double *alpha, double *a, lapack_int *lda, double *b, lapack_int *ldb); - -void ztrsm_(char *side, char *uplo, char *transa, char *diag, lapack_int *m, lapack_int *n, - dComplex *alpha, dComplex *a, lapack_int *lda, dComplex *b, lapack_int *ldb); - -void dgetrf_(lapack_int *m, lapack_int *n, double *a, lapack_int *lda, lapack_int *ipiv, - lapack_int *info); - -void zgetrf_(lapack_int *m, lapack_int *n, dComplex *a, lapack_int *lda, lapack_int *ipiv, - lapack_int *info); - -// Wrapper to call various blas core routine. Currently includes: -// - the SVD solver `dgesdd_` -// - the eigen vectors/values computation `dsyevd_` -// - the double (complex) triangular matrix equation solver `dtrsm_` (`ztrsm_`) -// - the double (complex) LU factorization `dgetrf_` (`zgetrf_`) -// from Lapack: -// https://github.com/google/jax/blob/main/jaxlib/cpu/lapack_kernels.cc released under the Apache -// License, Version 2.0, with the following copyright notice: - -// Copyright 2021 The JAX Authors. -void lapack_dgesdd(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 7; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 7; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int32_t job_opt_full_matrices = *(reinterpret_cast(data[0])); - int32_t job_opt_compute_uv = *(reinterpret_cast(data[1])); - int b = *(reinterpret_cast(data[2])); - int m = *(reinterpret_cast(data[3])); - int n = *(reinterpret_cast(data[4])); - int lwork = *(reinterpret_cast(data[5])); - double *a_in = reinterpret_cast(data[6]); - - double *a_out = reinterpret_cast(out[0]); - double *s = reinterpret_cast(out[1]); - // U and vt are switched to produce the right results... 
- double *vt = reinterpret_cast(out[2]); - double *u = reinterpret_cast(out[3]); - - int *info = reinterpret_cast(out[4]); - int *iwork = reinterpret_cast(out[5]); - double *work = reinterpret_cast(out[6]); - - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * static_cast(n) * - sizeof(double)); - } - - char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices); - - int lda = m; - int ldu = m; - int tdu = job_opt_full_matrices ? m : std::min(m, n); - int ldvt = job_opt_full_matrices ? n : std::min(m, n); - - for (int i = 0; i < b; ++i) { - dgesdd_(&jobz, &m, &n, a_out, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info); - a_out += static_cast(m) * n; - s += std::min(m, n); - u += static_cast(m) * tdu; - vt += static_cast(ldvt) * n; - ++info; - } -} - -// Copyright 2021 The JAX Authors. -void lapack_dsyevd(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 4; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 5; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int32_t lower = *(reinterpret_cast(data[0])); - int b = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const double *a_in = reinterpret_cast(data[3]); - - double *a_out = reinterpret_cast(out[0]); - double *w_out = reinterpret_cast(out[1]); - int *info_out = reinterpret_cast(out[2]); - double *work = reinterpret_cast(out[3]); - int *iwork = reinterpret_cast(out[4]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(n) * static_cast(n) * - sizeof(double)); - } - - char jobz = 'V'; - char uplo = lower ? 'L' : 'U'; - - lapack_int lwork = - std::min(std::numeric_limits::max(), 1 + 6 * n + 2 * n * n); - lapack_int liwork = std::min(std::numeric_limits::max(), 3 + 5 * n); - for (int i = 0; i < b; ++i) { - dsyevd_(&jobz, &uplo, &n, a_out, &n, w_out, work, &lwork, iwork, &liwork, info_out); - a_out += static_cast(n) * n; - w_out += n; - ++info_out; - } -} - -// Copyright 2021 The JAX Authors. -void blas_dtrsm(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 10; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 1; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int32_t left_side = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t trans_a = *reinterpret_cast(data[2]); - int32_t diag = *reinterpret_cast(data[3]); - int m = *reinterpret_cast(data[4]); - int n = *reinterpret_cast(data[5]); - int batch = *reinterpret_cast(data[6]); - double *alpha = reinterpret_cast(data[7]); - double *a = reinterpret_cast(data[8]); - double *b = reinterpret_cast(data[9]); - - double *x = reinterpret_cast(out[0]); - if (x != b) { - std::memcpy(x, b, - static_cast(batch) * static_cast(m) * - static_cast(n) * sizeof(double)); - } - - char cside = left_side ? 'L' : 'R'; - char cuplo = lower ? 'L' : 'U'; - char ctransa = 'N'; - if (trans_a == 1) { - ctransa = 'T'; - } - else if (trans_a == 2) { - ctransa = 'C'; - } - char cdiag = diag ? 'U' : 'N'; - int lda = left_side ? 
m : n; - int ldb = m; - - int64_t x_plus = static_cast(m) * static_cast(n); - int64_t a_plus = static_cast(lda) * static_cast(lda); - - for (int i = 0; i < batch; ++i) { - dtrsm_(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb); - x += x_plus; - a += a_plus; - } -} - -// Copyright 2021 The JAX Authors. -void blas_ztrsm(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 10; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 1; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int32_t left_side = *reinterpret_cast(data[0]); - int32_t lower = *reinterpret_cast(data[1]); - int32_t trans_a = *reinterpret_cast(data[2]); - int32_t diag = *reinterpret_cast(data[3]); - int m = *reinterpret_cast(data[4]); - int n = *reinterpret_cast(data[5]); - int batch = *reinterpret_cast(data[6]); - dComplex *alpha = reinterpret_cast(data[7]); - dComplex *a = reinterpret_cast(data[8]); - dComplex *b = reinterpret_cast(data[9]); - - dComplex *x = reinterpret_cast(out[0]); - if (x != b) { - std::memcpy(x, b, - static_cast(batch) * static_cast(m) * - static_cast(n) * sizeof(dComplex)); - } - - char cside = left_side ? 'L' : 'R'; - char cuplo = lower ? 'L' : 'U'; - char ctransa = 'N'; - if (trans_a == 1) { - ctransa = 'T'; - } - else if (trans_a == 2) { - ctransa = 'C'; - } - char cdiag = diag ? 'U' : 'N'; - int lda = left_side ? m : n; - int ldb = m; - - int64_t x_plus = static_cast(m) * static_cast(n); - int64_t a_plus = static_cast(lda) * static_cast(lda); - - for (int i = 0; i < batch; ++i) { - ztrsm_(&cside, &cuplo, &ctransa, &cdiag, &m, &n, alpha, a, &lda, x, &ldb); - x += x_plus; - a += a_plus; - } -} - -// Copyright 2021 The JAX Authors. -void lapack_dgetrf(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 4; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 3; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const double *a_in = reinterpret_cast(data[3]); - - double *a_out = reinterpret_cast(out[0]); - int *ipiv = reinterpret_cast(out[1]); - int *info = reinterpret_cast(out[2]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * static_cast(n) * - sizeof(double)); - } - for (int i = 0; i < b; ++i) { - dgetrf_(&m, &n, a_out, &m, ipiv, info); - a_out += static_cast(m) * static_cast(n); - ipiv += std::min(m, n); - ++info; - } -} - -// Copyright 2021 The JAX Authors. 
-void lapack_zgetrf(void **dataEncoded, void **resultsEncoded) -{ - std::vector data; - for (size_t i = 0; i < 4; ++i) { - auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); - data.push_back(encodedMemref.data_aligned); - } - - std::vector out; - for (size_t i = 0; i < 3; ++i) { - auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); - out.push_back(encodedMemref.data_aligned); - } - - int b = *(reinterpret_cast(data[0])); - int m = *(reinterpret_cast(data[1])); - int n = *(reinterpret_cast(data[2])); - const dComplex *a_in = reinterpret_cast(data[3]); - - dComplex *a_out = reinterpret_cast(out[0]); - int *ipiv = reinterpret_cast(out[1]); - int *info = reinterpret_cast(out[2]); - if (a_out != a_in) { - std::memcpy(a_out, a_in, - static_cast(b) * static_cast(m) * static_cast(n) * - sizeof(dComplex)); - } - for (int i = 0; i < b; ++i) { - zgetrf_(&m, &n, a_out, &m, ipiv, info); - a_out += static_cast(m) * static_cast(n); - ipiv += std::min(m, n); - ++info; - } -} -} +#define DEFINE_LAPACK_FUNC(FUNC_NAME, DATA_SIZE, OUT_SIZE, KERNEL) \ + extern "C" { \ + void FUNC_NAME(void **dataEncoded, void **resultsEncoded) \ + { \ + DEBUG_MSG(#FUNC_NAME); \ + void *data[DATA_SIZE]; \ + for (size_t i = 0; i < DATA_SIZE; ++i) { \ + auto encodedMemref = *(reinterpret_cast(dataEncoded[i])); \ + data[i] = encodedMemref.data_aligned; \ + } \ + \ + if (OUT_SIZE > 1) { \ + void *out[OUT_SIZE]; \ + for (size_t i = 0; i < OUT_SIZE; ++i) { \ + auto encodedMemref = *(reinterpret_cast(resultsEncoded[i])); \ + out[i] = encodedMemref.data_aligned; \ + } \ + KERNEL::Kernel(out, data, nullptr); \ + } \ + else { \ + auto encodedMemref = *(reinterpret_cast(resultsEncoded[0])); \ + KERNEL::Kernel(encodedMemref.data_aligned, data, nullptr); \ + } \ + } \ + } + +DEFINE_LAPACK_FUNC(blas_strsm, 10, 1, jax::RealTrsm) +DEFINE_LAPACK_FUNC(blas_dtrsm, 10, 1, jax::RealTrsm) +DEFINE_LAPACK_FUNC(blas_ctrsm, 10, 1, jax::ComplexTrsm>) +DEFINE_LAPACK_FUNC(blas_ztrsm, 10, 1, jax::ComplexTrsm>) + +DEFINE_LAPACK_FUNC(lapack_sgetrf, 4, 3, jax::Getrf) +DEFINE_LAPACK_FUNC(lapack_dgetrf, 4, 3, jax::Getrf) +DEFINE_LAPACK_FUNC(lapack_cgetrf, 4, 3, jax::Getrf>) +DEFINE_LAPACK_FUNC(lapack_zgetrf, 4, 3, jax::Getrf>) + +DEFINE_LAPACK_FUNC(lapack_sgeqrf, 5, 4, jax::Geqrf) +DEFINE_LAPACK_FUNC(lapack_dgeqrf, 5, 4, jax::Geqrf) +DEFINE_LAPACK_FUNC(lapack_cgeqrf, 5, 4, jax::Geqrf>) +DEFINE_LAPACK_FUNC(lapack_zgeqrf, 5, 4, jax::Geqrf>) + +DEFINE_LAPACK_FUNC(lapack_sorgqr, 7, 3, jax::Orgqr) +DEFINE_LAPACK_FUNC(lapack_dorgqr, 7, 3, jax::Orgqr) +DEFINE_LAPACK_FUNC(lapack_cungqr, 7, 3, jax::Orgqr>) +DEFINE_LAPACK_FUNC(lapack_zungqr, 7, 3, jax::Orgqr>) + +DEFINE_LAPACK_FUNC(lapack_spotrf, 4, 2, jax::Potrf) +DEFINE_LAPACK_FUNC(lapack_dpotrf, 4, 2, jax::Potrf) +DEFINE_LAPACK_FUNC(lapack_cpotrf, 4, 2, jax::Potrf>) +DEFINE_LAPACK_FUNC(lapack_zpotrf, 4, 2, jax::Potrf>) + +DEFINE_LAPACK_FUNC(lapack_sgesdd, 7, 7, jax::RealGesdd) +DEFINE_LAPACK_FUNC(lapack_dgesdd, 7, 7, jax::RealGesdd) +DEFINE_LAPACK_FUNC(lapack_cgesdd, 7, 8, jax::ComplexGesdd>) +DEFINE_LAPACK_FUNC(lapack_zgesdd, 7, 8, jax::ComplexGesdd>) + +DEFINE_LAPACK_FUNC(lapack_ssyevd, 4, 5, jax::RealSyevd) +DEFINE_LAPACK_FUNC(lapack_dsyevd, 4, 5, jax::RealSyevd) +DEFINE_LAPACK_FUNC(lapack_cheevd, 4, 6, jax::ComplexHeevd>) +DEFINE_LAPACK_FUNC(lapack_zheevd, 4, 6, jax::ComplexHeevd>) + +DEFINE_LAPACK_FUNC(lapack_sgeev, 5, 6, jax::RealGeev) +DEFINE_LAPACK_FUNC(lapack_dgeev, 5, 6, jax::RealGeev) +DEFINE_LAPACK_FUNC(lapack_cgeev, 5, 6, jax::ComplexGeev>) +DEFINE_LAPACK_FUNC(lapack_zgeev, 5, 6, 
+
+DEFINE_LAPACK_FUNC(lapack_sgees, 5, 6, jax::RealGees<float>)
+DEFINE_LAPACK_FUNC(lapack_dgees, 5, 6, jax::RealGees<double>)
+DEFINE_LAPACK_FUNC(lapack_cgees, 5, 6, jax::ComplexGees<std::complex<float>>)
+DEFINE_LAPACK_FUNC(lapack_zgees, 5, 6, jax::ComplexGees<std::complex<double>>)
+
+DEFINE_LAPACK_FUNC(lapack_sgehrd, 7, 4, jax::Gehrd<float>)
+DEFINE_LAPACK_FUNC(lapack_dgehrd, 7, 4, jax::Gehrd<double>)
+DEFINE_LAPACK_FUNC(lapack_cgehrd, 7, 4, jax::Gehrd<std::complex<float>>)
+DEFINE_LAPACK_FUNC(lapack_zgehrd, 7, 4, jax::Gehrd<std::complex<double>>)
+
+DEFINE_LAPACK_FUNC(lapack_ssytrd, 6, 6, jax::Sytrd<float>)
+DEFINE_LAPACK_FUNC(lapack_dsytrd, 6, 6, jax::Sytrd<double>)
+DEFINE_LAPACK_FUNC(lapack_chetrd, 6, 6, jax::Sytrd<std::complex<float>>)
+DEFINE_LAPACK_FUNC(lapack_zhetrd, 6, 6, jax::Sytrd<std::complex<double>>)
diff --git a/frontend/test/pytest/test_callback.py b/frontend/test/pytest/test_callback.py
index 0a63c14065..ee2efd69bc 100644
--- a/frontend/test/pytest/test_callback.py
+++ b/frontend/test/pytest/test_callback.py
@@ -1386,14 +1386,10 @@ def mul_jax(A, B):
 
 @pytest.mark.parametrize("arg", [jnp.array([[0.1, 0.2], [0.3, 0.4]])])
-@pytest.mark.parametrize("order", ["good", "bad"])
+@pytest.mark.parametrize("order", ["truth_hypo", "hypo_truth"])
 def test_vjp_as_residual(arg, order):
     """See https://github.com/PennyLaneAI/catalyst/issues/852"""
 
-    if order == "bad":
-        # See https://github.com/PennyLaneAI/catalyst/issues/894
-        pytest.skip("Bug")
-
     def jax_callback(fn, result_type):
 
         @pure_callback
@@ -1421,7 +1417,7 @@ def hypothesis(x):
     def ground_truth(x):
         return jax.scipy.linalg.expm(x)
 
-    if order == "bad":
+    if order == "hypo_truth":
         obs = hypothesis(arg)
         exp = ground_truth(arg)
     else:
@@ -1431,14 +1427,10 @@ def ground_truth(x):
 
 @pytest.mark.parametrize("arg", [jnp.array([[0.1, 0.2], [0.3, 0.4]])])
-@pytest.mark.parametrize("order", ["good", "bad"])
+@pytest.mark.parametrize("order", ["truth_hypo", "hypo_truth"])
 def test_vjp_as_residual_automatic(arg, order):
     """Test automatic differentiation of accelerated function"""
 
-    if order == "bad":
-        # See https://github.com/PennyLaneAI/catalyst/issues/894
-        pytest.skip("Bug")
-
     @qml.qjit
     @jacobian
     def hypothesis(x):
@@ -1448,7 +1440,7 @@ def hypothesis(x):
     def ground_truth(x):
         return jax.scipy.linalg.expm(x)
 
-    if order == "bad":
+    if order == "hypo_truth":
         obs = hypothesis(arg)
         exp = ground_truth(arg)
     else:
diff --git a/frontend/test/pytest/test_contexts.py b/frontend/test/pytest/test_contexts.py
index d4f22b3f0c..12de518444 100644
--- a/frontend/test/pytest/test_contexts.py
+++ b/frontend/test/pytest/test_contexts.py
@@ -17,11 +17,41 @@
 import pennylane as qml
 import pytest
 
-from catalyst import cond, grad, jacobian, measure, qjit, while_loop
-from catalyst.tracing.contexts import EvaluationContext, EvaluationMode, GradContext
+from catalyst import accelerate, cond, grad, jacobian, measure, qjit, while_loop
+from catalyst.tracing.contexts import (
+    AccelerateContext,
+    EvaluationContext,
+    EvaluationMode,
+    GradContext,
+)
 
 # pylint: disable=protected-access
 
 
+class TestAccelerateContext:
+    """Unit tests for accelerate context"""
+
+    def test_in_accelerate_context(self):
+        """Test that AccelerateContext returns True when in an accelerate context."""
+
+        @qjit
+        @accelerate
+        def identity(x: float):
+            assert AccelerateContext.am_inside_accelerate()
+            return x
+
+        identity(1.0)
+
+    def test_not_in_accelerate_context(self):
+        """Test that AccelerateContext returns False when not in an accelerate context."""
+
+        @qjit
+        def identity(x: float):
+            assert not AccelerateContext.am_inside_accelerate()
+            return x
+
+        identity(1.0)
+
+
 class TestGradContextUnitTests:
     """Unit tests for grad
context""" diff --git a/frontend/test/pytest/test_device_api.py b/frontend/test/pytest/test_device_api.py index f793a4e07c..a2a470c651 100644 --- a/frontend/test/pytest/test_device_api.py +++ b/frontend/test/pytest/test_device_api.py @@ -19,7 +19,7 @@ import pennylane as qml import pytest from pennylane.devices import Device -from pennylane.devices.execution_config import DefaultExecutionConfig, ExecutionConfig +from pennylane.devices.execution_config import ExecutionConfig from pennylane.transforms import split_non_commuting from pennylane.transforms.core import TransformProgram @@ -60,8 +60,10 @@ def execute(self, circuits, execution_config): """Execution.""" return circuits, execution_config - def preprocess(self, execution_config: ExecutionConfig = DefaultExecutionConfig): + def preprocess(self, execution_config=None): """Preprocessing.""" + if execution_config is None: + execution_config = ExecutionConfig() transform_program = TransformProgram() transform_program.add_transform(split_non_commuting) return transform_program, execution_config @@ -105,7 +107,8 @@ def test_qjit_device(): # Check the preprocess of the new device with EvaluationContext(EvaluationMode.QUANTUM_COMPILATION) as ctx: - transform_program, _ = device_qjit.preprocess(ctx) + execution_config = ExecutionConfig() + transform_program, _ = device_qjit.preprocess(ctx, execution_config) assert transform_program assert len(transform_program) == 3 assert transform_program[-2]._transform.__name__ == "verify_operations" diff --git a/frontend/test/pytest/test_gradient.py b/frontend/test/pytest/test_gradient.py index 64f1103da1..792869ecd6 100644 --- a/frontend/test/pytest/test_gradient.py +++ b/frontend/test/pytest/test_gradient.py @@ -1535,7 +1535,7 @@ def f(x): return qml.expval(qml.PauliX(0)) def g(x): - return mitigate_with_zne(f, scale_factors=jax.numpy.array([1, 2, 3]))(x) + return mitigate_with_zne(f, scale_factors=[1, 3, 5])(x) with pytest.raises(CompileError, match=".*Compilation failed.*"): diff --git a/frontend/test/pytest/test_jax_linalg.py b/frontend/test/pytest/test_jax_linalg.py new file mode 100644 index 0000000000..cf4118f7ea --- /dev/null +++ b/frontend/test/pytest/test_jax_linalg.py @@ -0,0 +1,1072 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test that the jax linear algebra functions yield correct results when compiled with qml.qjit""" + +import numpy as np +import pytest +from jax import numpy as jnp +from jax import scipy as jsp + +from catalyst import qjit + +# pylint: disable=too-many-lines + + +class MatrixGenerator: + """ + A class for generating random matrices. + + Each static method instantiates its own random number generator with a fixed seed to + make the generated matrices deterministic and reproducible. + """ + + @staticmethod + def random_real_matrix(m, n, positive=False, seed=42, dtype=None): + """ + Generate a random m x n real matrix. 
+ + By default, this method generates a matrix with elements that are real numbers + on the interval [-1, 1). If the `positive` option is True, then it generates a + positive matrix with elements on the interval [0, 1). + + This is a wrapper function for numpy.random.Generator.uniform: + + https://numpy.org/doc/stable/reference/random/generated/numpy.random.Generator.uniform.html + + Parameters + ---------- + m : int + Number of rows in the matrix. + n : int + Number of columns in the matrix. + positive : bool, optional + If true, generate a positive matrix (default is false). + seed : int, optional + Seed for the random number generator (default is 42). + dtype : str or numpy.dtype, optional + Data type of the returned matrix. If None (the default), no data-type + casting is performed. + + Returns + ------- + numpy.ndarray + An m x n matrix with random real values. + """ + rng = np.random.default_rng(seed) + lo = 0 if positive else -1 + hi = 1 + A = rng.uniform(lo, hi, (m, n)) + + if dtype is None: + return A + else: + return A.astype(dtype) + + @staticmethod + def random_integer_matrix(m, n, lo, hi, seed=42, dtype=None): + """ + Generate a random m x n integer matrix. + + The elements of the generated matrix are on the interval [`lo`, `hi`). + + This is a wrapper function for numpy.random.Generator.integers: + + https://numpy.org/doc/stable/reference/random/generated/numpy.random.Generator.integers.html + + Parameters + ---------- + m : int + Number of rows in the matrix. + n : int + Number of columns in the matrix. + lo : int + Lowest (signed) integers to be drawn from the distribution. + hi : int + One above the largest (signed) integer to be drawn from the distribution. + seed : int, optional + Seed for the random number generator (default is 42). + dtype : str or numpy.dtype, optional + Data type of the returned matrix. If None (the default), no data-type + casting is performed. + + Returns + ------- + numpy.ndarray + An m x n matrix with random integer values. + """ + rng = np.random.default_rng(seed) + A = rng.integers(lo, hi, (m, n)) + + if dtype is None: + return A + else: + return A.astype(dtype) + + @staticmethod + def random_complex_matrix(m, n, seed=42, dtype=None): + """ + Generate a random m x n complex matrix. + + This method generates two matrices A and B using numpy.random.Generator.uniform + and returns the sum C = A + iB. The real and imaginary parts of each element of + the generated matrix are on the interval [-1, 1). + + Parameters + ---------- + m : int + Number of rows in the matrix. + n : int + Number of columns in the matrix. + seed : int, optional + Seed for the random number generator (default is 42). + dtype : str or numpy.dtype, optional + Data type of the real and imaginary parts of the returned matrix. If None + (the default), no data-type casting is performed. For example, if + `dtype=np.float32`, the returned matrix is of type `np.complex64`. + + Returns + ------- + numpy.ndarray + An m x n matrix with random complex values. + """ + rng = np.random.default_rng(seed) + A = rng.uniform(-1, 1, (m, n)) + B = rng.uniform(-1, 1, (m, n)) + + if dtype is not None: + A = A.astype(dtype) + B = B.astype(dtype) + + return A + 1j * B + + def random_real_symmetric_matrix(n, positive=False, seed=42, dtype=None): + """ + Generate a random n x n real symmetric matrix. + + This method generates a matrix A with elements on the interval [-1, 1). If the + `positive` option is True, then it generates a positive matrix with elements on + the interval [0, 1). 
It then returns a symmetric matrix computed as
+
+            S = (A + A^T) / 2.
+
+        Parameters
+        ----------
+        n : int
+            Number of rows and columns in the matrix.
+        positive : bool, optional
+            If true, generate a positive matrix (default is false).
+        seed : int, optional
+            Seed for the random number generator (default is 42).
+        dtype : str or numpy.dtype, optional
+            Data type of the returned matrix. If None (the default), no data-type
+            casting is performed.
+
+        Returns
+        -------
+        numpy.ndarray
+            An n x n symmetric matrix with random real values.
+        """
+        rng = np.random.default_rng(seed)
+        lo = 0 if positive else -1
+        hi = 1
+        A = rng.uniform(lo, hi, (n, n))
+
+        if dtype is not None:
+            A = A.astype(dtype)
+
+        S = (A + A.T) / 2
+        assert np.allclose(S, S.T)  # assert that the matrix is symmetric
+
+        return S
+
+    def random_real_symmetric_positive_definite_matrix(n, seed=42, dtype=None):
+        """
+        Generate a random n x n real symmetric positive-definite matrix.
+
+        This method generates a real lower-triangular positive matrix L and computes a
+        symmetric positive-definite matrix S from its Cholesky factorization:
+
+            S = L L†
+
+        Parameters
+        ----------
+        n : int
+            Number of rows and columns in the matrix.
+        seed : int, optional
+            Seed for the random number generator (default is 42).
+        dtype : str or numpy.dtype, optional
+            Data type of the returned matrix. If None (the default), no data-type
+            casting is performed.
+
+        Returns
+        -------
+        numpy.ndarray
+            An n x n symmetric positive-definite matrix with random real values.
+        """
+        L = np.tril(
+            MatrixGenerator.random_real_symmetric_matrix(n, positive=True, seed=seed, dtype=dtype)
+        )
+        S = L @ L.T
+
+        assert np.allclose(S, S.T)  # assert that the matrix is symmetric
+        assert np.all(np.linalg.eigvalsh(S) > 0)  # assert that the matrix is positive-definite
+
+        return S
+
+    def random_complex_hermitian_matrix(n, seed=42, dtype=None):
+        """
+        Generate a random n x n complex Hermitian matrix.
+
+        This method generates two matrices A and B using numpy.random.Generator.uniform
+        and defines a complex matrix C = A + iB. The real and imaginary parts of each
+        element of the generated matrix are on the interval [-1, 1). It then returns a
+        Hermitian matrix computed as H = (C + C†) / 2.
+
+        Parameters
+        ----------
+        n : int
+            Number of rows and columns in the matrix.
+        seed : int, optional
+            Seed for the random number generator (default is 42).
+        dtype : str or numpy.dtype, optional
+            Data type of the real and imaginary parts of the returned matrix. If None
+            (the default), no data-type casting is performed. For example, if
+            `dtype=np.float32`, the returned matrix is of type `np.complex64`.
+
+        Returns
+        -------
+        numpy.ndarray
+            An n x n Hermitian matrix with random complex values.
+        """
+        rng = np.random.default_rng(seed)
+        A = rng.uniform(-1, 1, (n, n))
+        B = rng.uniform(-1, 1, (n, n))
+
+        if dtype is not None:
+            A = A.astype(dtype)
+            B = B.astype(dtype)
+
+        C = A + 1j * B
+
+        H = (C + C.T.conj()) / 2
+        assert np.allclose(H, H.T.conj())  # assert that the matrix is Hermitian
+
+        return H
+
+    def random_complex_hermitian_positive_definite_matrix(n, seed=42, dtype=None):
+        """
+        Generate a random n x n complex Hermitian positive-definite matrix.
+
+        This method generates a complex lower-triangular matrix L and computes a
+        Hermitian positive-definite matrix H from its Cholesky factorization:
+
+            H = L L†
+
+        Parameters
+        ----------
+        n : int
+            Number of rows and columns in the matrix.
+        seed : int, optional
+            Seed for the random number generator (default is 42).
+        dtype : str or numpy.dtype, optional
+            Data type of the real and imaginary parts of the returned matrix. If None
+            (the default), no data-type casting is performed. For example, if
+            `dtype=np.float32`, the returned matrix is of type `np.complex64`.
+
+        Returns
+        -------
+        numpy.ndarray
+            An n x n Hermitian positive-definite matrix with random complex values.
+        """
+        L = np.tril(MatrixGenerator.random_complex_hermitian_matrix(n, seed, dtype))
+        H = L @ L.T.conj()
+
+        assert np.allclose(H, H.T.conj())  # assert that the matrix is Hermitian
+        assert np.all(np.linalg.eigvalsh(H) > 0)  # assert that the matrix is positive-definite
+
+        return H
+
+
+class TestCholesky:
+    """Test results of jax.scipy.linalg.cholesky are numerically correct when qjit compiled.
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html
+
+    The Cholesky decomposition of a Hermitian positive-definite matrix H is a
+    factorization of the form
+
+        H = L L†  or  H = U† U,
+
+    where L is a lower-triangular matrix with real and positive diagonal entries, and
+    similarly where U is an upper-triangular matrix.
+    """
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(9, seed=13)),
+            # Complex matrices
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(2, seed=21)
+            ),
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(3, seed=22)
+            ),
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(9, seed=23)
+            ),
+        ],
+    )
+    def test_cholesky_numerical_lower(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.cholesky with option
+        lower=True (giving a decomposition of the form A = L L†) for matrices of
+        various data types and sizes.
+        """
+
+        @qjit
+        def f(X):
+            return jsp.linalg.cholesky(X, lower=True)
+
+        L_obs = f(A)
+        L_exp = jsp.linalg.cholesky(A, lower=True)
+
+        assert np.allclose(L_exp @ L_exp.T.conj(), A)  # Check jax solution is correct
+        assert jnp.allclose(L_obs, L_exp)
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_symmetric_positive_definite_matrix(9, seed=13)),
+            # Complex matrices
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(2, seed=21)
+            ),
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(3, seed=22)
+            ),
+            jnp.array(
+                MatrixGenerator.random_complex_hermitian_positive_definite_matrix(9, seed=23)
+            ),
+        ],
+    )
+    def test_cholesky_numerical_upper(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.cholesky with option
+        lower=False (giving a decomposition of the form A = U† U) for matrices of
+        various data types and sizes.
+        """
+
+        @qjit
+        def f(X):
+            return jsp.linalg.cholesky(X, lower=False)
+
+        U_obs = f(A)
+        U_exp = jsp.linalg.cholesky(A, lower=False)
+
+        assert np.allclose(U_exp.T.conj() @ U_exp, A)  # Check jax solution is correct
+        assert jnp.allclose(U_obs, U_exp)
+
+
+class TestExpm:
+    """Test results of jax.scipy.linalg.expm are numerically correct when qjit compiled.
+ + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html + + The exponential of a real or complex n x n matrix X, denoted exp(X), has the key + property that if XY = YX, then + + exp(X) exp(Y) = exp(X+Y). + """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + # Integer matrices + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=21)), + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=0, hi=9, seed=22)), + jnp.array(np.triu(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=23))), + jnp.array(np.tril(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=24))), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=31)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=32))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=33))), + # Special case from https://github.com/PennyLaneAI/catalyst/issues/1071 + jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [3.0, 2.0, 1.0]]), + ], + ) + def test_expm_numerical(self, A): + """Test basic numerical correctness of jax.scipy.linalg.expm for matrices of + various data types and sizes. + """ + + @qjit + def f(X): + return jsp.linalg.expm(X) + + expmA_obs = f(A) + expmA_exp = jsp.linalg.expm(A) + + assert jnp.allclose( + jsp.linalg.expm(A + A), jsp.linalg.expm(A) @ jsp.linalg.expm(A) + ) # Check jax solution is correct + assert np.allclose(expmA_obs, expmA_exp) + + @pytest.mark.parametrize( + "A", + [ + # Real upper-triangular matrices + jnp.array(np.triu(MatrixGenerator.random_real_matrix(2, 2, seed=41))), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=42))), + # Integer upper-triangular matrices + jnp.array(np.triu(MatrixGenerator.random_integer_matrix(2, 2, lo=-9, hi=9, seed=43))), + jnp.array(np.triu(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=44))), + # Complex upper-triangular matrices + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(2, 2, seed=45))), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=46))), + ], + ) + def test_expm_numerical_upper_triangular(self, A): + """Test basic numerical correctness of jax.scipy.linalg.expm for matrices of + various data types and sizes when using the `upper_triangular=True` option. + """ + + @qjit + def f(X): + return jsp.linalg.expm(X, upper_triangular=True) + + expmA_obs = f(A) + expmA_exp = jsp.linalg.expm(A, upper_triangular=True) + + assert jnp.allclose( + jsp.linalg.expm(A + A), jsp.linalg.expm(A) @ jsp.linalg.expm(A) + ) # Check jax solution is correct + assert np.allclose(expmA_obs, expmA_exp) + + +class TestFunmNumerical: + """Test results of jax.scipy.linalg.funm are numerically correct when qjit compiled. 
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html
+    """
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)),
+            jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))),
+            jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))),
+            # Complex matrices
+            jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)),
+            jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=22))),
+            jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=23))),
+        ],
+    )
+    def test_funm_numerical(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.funm for matrices of
+        various data types and sizes.
+        """
+
+        @qjit
+        def f(X):
+            def func(X):
+                return jnp.sin(X) + 2 * jnp.cos(X)
+
+            return jsp.linalg.funm(X, func)
+
+        fA_obs = f(A)
+        fA_exp = jsp.linalg.funm(A, lambda X: jnp.sin(X) + 2 * jnp.cos(X))
+
+        assert jnp.allclose(fA_obs, fA_exp)
+
+
+class TestHessenberg:
+    """Test results of jax.scipy.linalg.hessenberg are numerically correct when qjit compiled.
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html
+
+    The Hessenberg form H of an n x n matrix A satisfies
+
+        A = Q H Q†,
+
+    where Q is unitary and H is zero below the first subdiagonal.
+    """
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)),
+            jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))),
+            jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))),
+            # Complex matrices
+            jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)),
+            jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=22))),
+            jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=23))),
+        ],
+    )
+    def test_hessenberg_numerical(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.hessenberg for matrices
+        of various data types and sizes.
+
+        Note that jax does not support integer matrices for this function.
+        """
+
+        @qjit
+        def f(X):
+            return jsp.linalg.hessenberg(X, calc_q=True)
+
+        H_obs, Q_obs = f(A)
+        H_exp, Q_exp = jsp.linalg.hessenberg(A, calc_q=True)
+
+        assert jnp.allclose(Q_exp @ H_exp @ Q_exp.conj().T, A)  # Check jax solution is correct
+        assert jnp.allclose(H_obs, H_exp)
+        assert jnp.allclose(Q_obs, Q_exp)
+
+
+class TestLU:
+    """Test results of jax.scipy.linalg.lu are numerically correct when qjit compiled.
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html
+
+    The LU decomposition with partial pivoting of a real or complex n x n matrix A satisfies
+
+        A = P L U,
+
+    where P is a permutation matrix, L is lower triangular with unit diagonal elements, and
+    U is upper triangular.
+
+    JAX (and SciPy) also support LU decomposition of m x n matrices, in which case L has
+    dimension m x k, where k = min(m, n), and U has dimension k x n.
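+
+    A minimal sketch (not part of the original tests) of the property verified
+    below, using this module's imports:
+
+        >>> P, L, U = qjit(jsp.linalg.lu)(A)   # A: any square matrix from the cases below
+        >>> bool(jnp.allclose(P @ L @ U, A))
+        True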
+ """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(MatrixGenerator.random_real_matrix(5, 7, seed=15)), + jnp.array(MatrixGenerator.random_real_matrix(7, 5, seed=16)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=17))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=18))), + # Integer matrices + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=21)), + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=0, hi=9, seed=22)), + jnp.array(MatrixGenerator.random_integer_matrix(5, 7, lo=-9, hi=9, seed=23)), + jnp.array(MatrixGenerator.random_integer_matrix(7, 5, lo=-9, hi=9, seed=24)), + jnp.array(np.triu(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=25))), + jnp.array(np.tril(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=26))), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=31)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=32))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=33))), + ], + ) + def test_lu_numerical(self, A): + """Test basic numerical correctness for jax.scipy.linalg.lu for for matrices of + various data types and sizes. + """ + + @qjit + def f(X): + return jsp.linalg.lu(X) + + P_obs, L_obs, U_obs = f(A) + P_exp, L_exp, U_exp = jsp.linalg.lu(A) + + assert jnp.allclose(P_exp @ L_exp @ U_exp, A) # Check jax solution is correct + assert jnp.allclose(P_obs, P_exp) + assert jnp.allclose(L_obs, L_exp) + assert jnp.allclose(U_obs, U_exp) + + +class TestLUSolve: + """Test results of jax.scipy.linalg.lu_solve are numerically correct when qjit compiled. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html + + This method solves a linear system using an LU factorization (see above). + """ + + @pytest.mark.parametrize( + "A,b", + [ + # Real coefficient matrices + ( + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(2, 1, seed=12)), + ), + ( + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 1, seed=14)), + ), + ( + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=15)), + jnp.array(MatrixGenerator.random_real_matrix(9, 1, seed=16)), + ), + # Complex coefficient matrices + ( + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)), + jnp.array(MatrixGenerator.random_complex_matrix(3, 1, seed=22)), + ), + ], + ) + def test_lu_solve_numerical(self, A, b): + """Test basic numerical correctness of jax.scipy.linalg.lu_solve for matrices and + vectors of various data types and sizes. + + Note that jax does not support integer matrices for this function. + """ + + @qjit + def f(A, b): + lu_and_piv = jsp.linalg.lu_factor(A) + return jsp.linalg.lu_solve(lu_and_piv, b) + + x_obs = f(A, b) + lu_and_piv = jsp.linalg.lu_factor(A) + x_exp = jsp.linalg.lu_solve(lu_and_piv, b) + + assert jnp.allclose(A @ x_exp, b) # Check jax solution is correct + assert jnp.allclose(x_obs, x_exp) + + +class TestPolar: + """Test results of jax.scipy.linalg.polar are numerically correct when qjit compiled. 
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html
+
+    The polar decomposition of a real or complex n x n matrix A is a factorization of
+    the form
+
+        A = U P,
+
+    where U is a unitary matrix and P is a positive semi-definite Hermitian matrix.
+
+    JAX (and SciPy) also support polar decomposition of m x n matrices, in which case U
+    has dimension m x n and P has dimension n x n. This is known as "right-side" polar
+    decomposition. "Left-side" decomposition, in the form A = P U, is also possible, in
+    which case P has dimension m x m.
+    """
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)),
+            jnp.array(MatrixGenerator.random_real_matrix(5, 7, seed=15)),
+            jnp.array(MatrixGenerator.random_real_matrix(7, 5, seed=16)),
+            # Complex matrices
+            jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)),
+            jnp.array(MatrixGenerator.random_complex_matrix(5, 7, seed=22)),
+            jnp.array(MatrixGenerator.random_complex_matrix(7, 5, seed=23)),
+        ],
+    )
+    def test_polar_numerical_svd(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.polar for matrices
+        of various data types and sizes using the 'svd' method.
+
+        Note that jax does not support integer matrices for this function.
+        """
+
+        @qjit
+        def f(X):
+            return jsp.linalg.polar(X, method="svd")
+
+        U_obs, P_obs = f(A)
+        U_exp, P_exp = jsp.linalg.polar(A, method="svd")
+
+        assert jnp.allclose(U_exp @ P_exp, A)  # Check jax solution is correct
+        assert jnp.allclose(U_obs, U_exp)
+        assert jnp.allclose(P_obs, P_exp)
+
+    @pytest.mark.parametrize(
+        "A",
+        [
+            # Real matrices
+            jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)),
+            jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)),
+            jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)),
+            jnp.array(MatrixGenerator.random_real_matrix(7, 5, seed=15)),
+            # Complex matrices
+            jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)),
+            jnp.array(MatrixGenerator.random_complex_matrix(7, 5, seed=22)),
+        ],
+    )
+    def test_polar_numerical_qdwh(self, A):
+        """Test basic numerical correctness of jax.scipy.linalg.polar for matrices
+        of various data types and sizes using the 'qdwh' method.
+
+        Note that jax does not support integer matrices for this function.
+        """
+
+        @qjit
+        def f(X):
+            return jsp.linalg.polar(X, method="qdwh")
+
+        U_obs, P_obs = f(A)
+        U_exp, P_exp = jsp.linalg.polar(A, method="qdwh")
+
+        assert jnp.allclose(U_exp @ P_exp, A)  # Check jax solution is correct
+        assert jnp.allclose(U_obs, U_exp)
+        assert jnp.allclose(P_obs, P_exp)
+
+
+class TestQR:
+    """Test results of jax.scipy.linalg.qr are numerically correct when qjit compiled.
+
+    See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html
+
+    QR decomposition of a real or complex n x n matrix A is a factorization of the form
+
+        A = Q R,
+
+    where Q is a unitary matrix (i.e. Q† Q = Q Q† = I) and R is an upper-triangular
+    matrix (also called a right-triangular matrix).
+
+    JAX (and SciPy) also support QR decomposition of m x n matrices, in which case Q has
+    dimension m x k, where k = min(m, n), and R has dimension k x n.
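+
+    As an illustrative sketch (not part of the original tests, using this
+    module's imports):
+
+        >>> Q, R = qjit(jsp.linalg.qr)(A)   # A: any matrix from the cases below
+        >>> bool(jnp.allclose(Q @ R, A))
+        True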
+ """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + jnp.array(MatrixGenerator.random_real_matrix(5, 7, seed=17)), + jnp.array(MatrixGenerator.random_real_matrix(7, 5, seed=18)), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=22))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=23))), + jnp.array(MatrixGenerator.random_complex_matrix(5, 7, seed=24)), + jnp.array(MatrixGenerator.random_complex_matrix(7, 5, seed=25)), + ], + ) + def test_qr_numerical(self, A): + """Test basic numerical correctness of jax.scipy.linalg.qr for matrices + of various data types and sizes. + + Note that jax does not support integer matrices for this function. + """ + + @qjit + def f(X): + return jsp.linalg.qr(X) + + Q_obs, R_obs = f(A) + Q_exp, R_exp = jsp.linalg.qr(A) + + assert jnp.allclose(Q_exp @ R_exp, A) # Check jax solution is correct + assert jnp.allclose(Q_obs, Q_exp) + assert jnp.allclose(R_obs, R_exp) + + +class TestSchur: + """Test results of jax.scipy.linalg.schur are numerically correct when qjit compiled. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html + + The Schur decomposition of a real or complex n x n matrix A is a factorization of + the form + + A = Z T Z† + + where Z is a unitary matrix and T is an upper-triangular matrix. + """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=22))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=23))), + ], + ) + def test_schur_numerical_real(self, A): + """Test basic numerical correctness of jax.scipy.linalg.schur with output='real' + for matrices of various data types and sizes. + + Note that jax does not support integer matrices for this function. 
+ """ + + @qjit + def f(X): + return jsp.linalg.schur(X, output="real") + + T_obs, Z_obs = f(A) + T_exp, Z_exp = jsp.linalg.schur(A, output="real") + + assert jnp.allclose(Z_exp @ T_exp @ Z_exp.conj().T, A) # Check jax solution is correct + assert jnp.allclose(T_obs, T_exp) + assert jnp.allclose(Z_obs, Z_exp) + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=21)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=22))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=23))), + ], + ) + def test_schur_numerical_complex(self, A): + """Test basic numerical correctness of jax.scipy.linalg.schur with output='complex' + for matrices of various data types and sizes. + + Note that jax does not support integer matrices for this function. + """ + + @qjit + def f(X): + return jsp.linalg.schur(X, output="complex") + + T_obs, Z_obs = f(A) + T_exp, Z_exp = jsp.linalg.schur(A, output="complex") + + assert jnp.allclose(Z_exp @ T_exp @ Z_exp.conj().T, A) # Check jax solution is correct + assert jnp.allclose(T_obs, T_exp) + assert jnp.allclose(Z_obs, Z_exp) + + +class TestSolve: + """Test results of jax.scipy.linalg.solve are numerically correct when qjit compiled. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html + + Solves a linear system of equations of the form A x = b for x given the n x n matrix + A and length n vector b. + """ + + @pytest.mark.parametrize( + "A,b", + [ + # Real coefficient matrices + ( + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(2, 1, seed=12)), + ), + ( + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 1, seed=14)), + ), + ( + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=15)), + jnp.array(MatrixGenerator.random_real_matrix(9, 1, seed=16)), + ), + # Integer coefficient matrices + ( + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, -9, 9, seed=21)), + jnp.array(MatrixGenerator.random_integer_matrix(3, 1, -9, 9, seed=22)), + ), + # Complex coefficient matrices + ( + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=31)), + jnp.array(MatrixGenerator.random_complex_matrix(3, 1, seed=22)), + ), + ], + ) + def test_solve_numerical(self, A, b): + """Test basic numerical correctness of jax.scipy.linalg.solve for matrices and + vectors of various data types and sizes. 
+ """ + + @qjit + def f(A, b): + return jsp.linalg.solve(A, b) + + x_obs = f(A, b) + x_exp = jsp.linalg.solve(A, b) + + assert jnp.allclose(A @ x_exp, b) # Check jax solution is correct + assert jnp.allclose(x_obs, x_exp) + + @pytest.mark.parametrize( + "A,b", + [ + # Hermitian coefficient matrices + ( + jnp.array(MatrixGenerator.random_complex_hermitian_matrix(3, seed=11)), + jnp.array(MatrixGenerator.random_complex_matrix(3, 1, seed=12)), + ), + ], + ) + def test_solve_numerical_hermitian(self, A, b): + """Test basic numerical correctness of jax.scipy.linalg.solve for Hermitian + matrices and vectors of various data types and sizes to test the `assume_a="her"` + option. + """ + + @qjit + def f(A, b): + return jsp.linalg.solve(A, b, assume_a="her") + + x_obs = f(A, b) + x_exp = jsp.linalg.solve(A, b, assume_a="her") + + assert jnp.allclose(A @ x_exp, b) # Check jax solution is correct + assert jnp.allclose(x_obs, x_exp) + + +class TestSqrtm: + """Test results of jax.scipy.linalg.sqrtm are numerically correct when qjit compiled. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html + + The square root of a real or complex n x n matrix A, denoted sqrt(A), has the key + property that + + sqrt(A) sqrt(A) = A + """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + # Integer matrices + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=21)), + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=0, hi=9, seed=22)), + jnp.array(np.triu(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=23))), + jnp.array(np.tril(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=24))), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=31)), + jnp.array(np.triu(MatrixGenerator.random_complex_matrix(3, 3, seed=32))), + jnp.array(np.tril(MatrixGenerator.random_complex_matrix(3, 3, seed=33))), + ], + ) + def test_sqrtm_numerical(self, A): + """Test basic numerical correctness of jax.scipy.linalg.sqrtm for matrices of + various data types and sizes. + """ + + @qjit + def f(X): + return jsp.linalg.sqrtm(X) + + sqrtmA_obs = f(A) + sqrtmA_exp = jsp.linalg.sqrtm(A) + + assert jnp.allclose(sqrtmA_exp @ sqrtmA_exp, A) # Check jax solution is correct + assert jnp.allclose(sqrtmA_obs, sqrtmA_exp) + + +class TestSVD: + """Test results of jax.scipy.linalg.svd are numerically correct when qjit compiled. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html + + Singular value decomposition (SVD) of a real or complex m x n matrix A is a + factorization of the form + + A = U Σ V†, + + where U and V are m x n unitary matrices containing the left and right singular + vectors, respectively, and Σ is an m x n diagonal matrix of singular values. 
+ """ + + @pytest.mark.parametrize( + "A", + [ + # Real matrices + jnp.array(MatrixGenerator.random_real_matrix(2, 2, seed=11)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, seed=12)), + jnp.array(MatrixGenerator.random_real_matrix(9, 9, seed=13)), + jnp.array(MatrixGenerator.random_real_matrix(3, 3, positive=True, seed=14)), + jnp.array(np.triu(MatrixGenerator.random_real_matrix(3, 3, seed=15))), + jnp.array(np.tril(MatrixGenerator.random_real_matrix(3, 3, seed=16))), + jnp.array(MatrixGenerator.random_real_matrix(5, 7, seed=17)), + jnp.array(MatrixGenerator.random_real_matrix(7, 5, seed=18)), + # Integer matrices + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=-9, hi=9, seed=21)), + jnp.array(MatrixGenerator.random_integer_matrix(3, 3, lo=0, hi=9, seed=22)), + # Complex matrices + jnp.array(MatrixGenerator.random_complex_matrix(3, 3, seed=31)), + jnp.array(MatrixGenerator.random_complex_matrix(5, 7, seed=32)), + jnp.array(MatrixGenerator.random_complex_matrix(7, 5, seed=33)), + ], + ) + def test_svd_numerical(self, A): + """Test basic numerical correctness of jax.scipy.linalg.svd for matrices + of various data types and sizes. + """ + + @qjit + def f(X): + return jsp.linalg.svd(X) + + U_obs, S_obs, Vt_obs = f(A) + U_exp, S_exp, Vt_exp = jsp.linalg.svd(A) + + # Pad S_exp with rows/cols of zeros if input matrix is not square + S_padded = np.zeros(A.shape) + for i in range(min(A.shape)): + S_padded[i, i] = S_exp[i] + + assert jnp.allclose(U_exp @ S_padded @ Vt_exp, A) # Check jax solution is correct + assert jnp.allclose(U_obs, U_exp) + assert jnp.allclose(S_obs, S_exp) + assert jnp.allclose(Vt_obs, Vt_exp) diff --git a/frontend/test/pytest/test_jax_linalg_in_circuit.py b/frontend/test/pytest/test_jax_linalg_in_circuit.py new file mode 100644 index 0000000000..16a9d2226c --- /dev/null +++ b/frontend/test/pytest/test_jax_linalg_in_circuit.py @@ -0,0 +1,48 @@ +# Copyright 2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Test the jax linear algebra functions for full quantum-circuit workflows compiled with qjit. 
+""" + +import numpy as np +import pennylane as qml +from jax import numpy as jnp +from jax import scipy as jsp + +from catalyst import qjit + + +class TestExpmInCircuit: + """Test entire quantum workflows with jax.scipy.linag.expm""" + + def test_expm_in_circuit(self): + """Rotate |0> about Bloch x axis for 180 degrees to get |1>""" + + @qjit + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit_expm(): + generator = -1j * jnp.pi * jnp.array([[0, 1], [1, 0]]) / 2 + unitary = jsp.linalg.expm(generator) + qml.QubitUnitary(unitary, wires=[0]) + return qml.probs() + + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit_rot(): + qml.RX(np.pi, wires=[0]) + return qml.probs() + + res = circuit_expm() + expected = circuit_rot() # expected = [0,1] + assert np.allclose(res, expected) diff --git a/frontend/test/pytest/test_jax_numerical.py b/frontend/test/pytest/test_jax_numerical.py index 4c42b3bc3d..f4d739874a 100644 --- a/frontend/test/pytest/test_jax_numerical.py +++ b/frontend/test/pytest/test_jax_numerical.py @@ -14,19 +14,21 @@ """Test that numerical jax functions produce correct results when compiled with qml.qjit""" -import warnings - import numpy as np import pennylane as qml import pytest from jax import numpy as jnp from jax import scipy as jsp -from catalyst import accelerate, qjit +from catalyst import qjit + +class TestExpmAndSolve: + """Test that `jax.scipy.linalg.expm` and `jax.scipy.linalg.solve` can run together + in the same function scope but from different qjit blocks. -class TestExpmNumerical: - """Test jax.scipy.linalg.expm is numerically correct when being qjit compiled""" + Also test that their results are numerically correct when qjit compiled. + """ @pytest.mark.parametrize( "inp", @@ -53,6 +55,28 @@ def f(x): assert np.allclose(observed, expected) + def test_expm_and_solve(self): + """ + Test against the "gather rule not implemented" bug for + using expm and solve together. 
+ https://github.com/PennyLaneAI/catalyst/issues/1094 + """ + + A = jnp.array([[1.0, 0.0], [0.0, 1.0]]) + b = jnp.array([[0.1], [0.2]]) + + def f(A, b): + return jsp.linalg.solve(A, b) + + def g(A): + return jsp.linalg.expm(A) + + expected = [g(A), f(A, b)] # [e, 0; 0, e], [0.1; 0.2] + observed = [qjit(g)(A), qjit(f)(A, b)] + + assert np.allclose(expected[0], observed[0]) + assert np.allclose(expected[1], observed[1]) + class TestExpmInCircuit: """Test entire quantum workflows with jax.scipy.linag.expm""" @@ -78,123 +102,6 @@ def circuit_rot(): assert np.allclose(res, expected) -class TestExpmWarnings: - """Test jax.scipy.linalg.expm raises a warning when not used in accelerate callback""" - - """Remove the warnings module and this test when we have proper lapack calls""" - - def test_expm_warnings(self): - @qjit - def f(x): - expm = jsp.linalg.expm - return expm(x) - - with pytest.warns( - UserWarning, - match="jax.scipy.linalg.expm occasionally gives wrong numerical results", - ): - f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - - def test_accelerated_expm_no_warnings(self, recwarn): - @qjit - def f(x): - expm = accelerate(jsp.linalg.expm) - return expm(x) - - observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - expected = jsp.linalg.expm(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - assert len(recwarn) == 0 - assert np.allclose(observed, expected) - - -class TestLUWarnings: - """Test jax.scipy.linalg.lu raises a warning when not used in accelerate callback""" - - """Remove the warnings module and this test when we have proper lapack calls""" - - def test_lu_warnings(self): - @qjit - def f(x): - lu = jsp.linalg.lu - return lu(x) - - with pytest.warns( - UserWarning, - match="jax.scipy.linalg.lu occasionally gives wrong numerical results", - ): - f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - - def test_accelerated_lu_no_warnings(self, recwarn): - @qjit - def f(x): - lu = accelerate(jsp.linalg.lu) - return lu(x) - - observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - expected = jsp.linalg.lu(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - assert len(recwarn) == 0 - assert np.allclose(observed, expected) - - def test_lu_factor_warnings(self): - @qjit - def f(x): - luf = jsp.linalg.lu_factor - return luf(x) - - with pytest.warns( - UserWarning, - match="jax.scipy.linalg.lu_factor occasionally gives wrong numerical results", - ): - f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - - def test_accelerated_lu_factor_no_warnings(self, recwarn): - @qjit - def f(x): - luf = accelerate(jsp.linalg.lu_factor) - return luf(x) - - observed = f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - expected = jsp.linalg.lu_factor(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - assert len(recwarn) == 0 - assert np.allclose(observed[0], expected[0]) - assert np.allclose(observed[1], expected[1]) - - def test_lu_solve_warnings(self): - @qjit - def f(x): - lus = jsp.linalg.lu_solve - b = jnp.array([3.0, 4.0]) - B = accelerate(jsp.linalg.lu_factor)( - x - ) # since this is a lu_solve unit test, use accelerate for lu_factor - return lus(B, b) - - with pytest.warns( - UserWarning, - match="jax.scipy.linalg.lu_solve occasionally gives wrong numerical results", - ): - f(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - - def test_accelerated_lu_solve_no_warnings(self, recwarn): - @qjit - def f(x): - lus = accelerate(jsp.linalg.lu_solve) - b = jnp.array([3.0, 4.0]) - B = accelerate(jsp.linalg.lu_factor)(x) - return lus(B, b) - - def truth(x): - lus = jsp.linalg.lu_solve - b = jnp.array([3.0, 4.0]) - B = jsp.linalg.lu_factor(x) - return lus(B, b) - - observed = f(jnp.array([[0.1, 0.2], 
[5.3, 1.2]])) - expected = truth(jnp.array([[0.1, 0.2], [5.3, 1.2]])) - assert len(recwarn) == 0 - assert np.allclose(observed, expected) - - class TestArgsortNumerical: """Test jax.numpy.argsort sort arrays correctly when being qjit compiled""" diff --git a/frontend/test/pytest/test_mid_circuit_measurement.py b/frontend/test/pytest/test_mid_circuit_measurement.py index 6349f5c6e0..c18cc46670 100644 --- a/frontend/test/pytest/test_mid_circuit_measurement.py +++ b/frontend/test/pytest/test_mid_circuit_measurement.py @@ -403,13 +403,13 @@ def circuit(): @qjit def mitigated_circuit_1(): - s = jnp.array([1, 2]) + s = [1, 3] g = qml.QNode(circuit, dev, mcm_method="one-shot") return mitigate_with_zne(g, scale_factors=s)() @qjit def mitigated_circuit_2(): - s = jnp.array([1, 2]) + s = [1, 3] g = qml.QNode(circuit, dev) return mitigate_with_zne(g, scale_factors=s)() diff --git a/frontend/test/pytest/test_mitigation.py b/frontend/test/pytest/test_mitigation.py index ec0dfda7a3..64b491423f 100644 --- a/frontend/test/pytest/test_mitigation.py +++ b/frontend/test/pytest/test_mitigation.py @@ -26,17 +26,19 @@ quadratic_extrapolation = polynomial_extrapolation(2) -def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func): +def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func, threshold): """skip test if exponential extrapolation will be unstable""" - if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate: + if circuit_param <= threshold and extrapolation_func == exponential_extrapolate: pytest.skip("Exponential extrapolation unstable in this region.") @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_single_measurement(params, extrapolation): +@pytest.mark.parametrize("scale_factors", [[1, 3, 5, 7], [3, 7, 21, 29]]) +@pytest.mark.parametrize("folding", ["global", "local-all"]) +def test_single_measurement(params, extrapolation, folding, scale_factors): """Test that without noise the same results are returned for single measurements.""" - skip_if_exponential_extrapolation_unstable(params, extrapolation) + skip_if_exponential_extrapolation_unstable(params, extrapolation, threshold=0.2) dev = qml.device("lightning.qubit", wires=2) @@ -52,7 +54,10 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, + scale_factors=scale_factors, + extrapolate=extrapolation, + folding=folding, )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @@ -60,9 +65,10 @@ def mitigated_qnode(args): @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_multiple_measurements(params, extrapolation): +@pytest.mark.parametrize("folding", ["global", "local-all"]) +def test_multiple_measurements(params, extrapolation, folding): """Test that without noise the same results are returned for multiple measurements""" - skip_if_exponential_extrapolation_unstable(params, extrapolation) + skip_if_exponential_extrapolation_unstable(params, extrapolation, threshold=0.5) dev = qml.device("lightning.qubit", wires=2) @@ -78,14 +84,18 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + 
circuit, + scale_factors=[1, 3, 5, 7], + extrapolate=extrapolation, + folding=folding, )(args) assert np.allclose(mitigated_qnode(params), circuit(params)) @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) -def test_single_measurement_control_flow(params): +@pytest.mark.parametrize("folding", ["global", "local-all"]) +def test_single_measurement_control_flow(params, folding): """Test that without noise the same results are returned for single measurement and with control flow.""" dev = qml.device("lightning.qubit", wires=2) @@ -113,13 +123,28 @@ def loop_1(i): # pylint: disable=unused-argument @catalyst.qjit def mitigated_qnode(args, n): - return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]))( + return catalyst.mitigate_with_zne(circuit, scale_factors=[1, 3, 5, 7], folding=folding)( args, n ) assert np.allclose(mitigated_qnode(params, 3), catalyst.qjit(circuit)(params, 3)) +@pytest.mark.parametrize("scale_factors", [[1.0, 3, 5, 7], [-1, 3, 5, 7], [1, 2, 5, 7]]) +def test_scale_factors_error(scale_factors): + """Test that an error is raised when the scale factors are not positive odd integers.""" + + def circuit(x): + return jax.numpy.sin(x) + + @catalyst.qjit + def mitigated_function(args): + return catalyst.mitigate_with_zne(circuit, scale_factors=scale_factors)(args) + + with pytest.raises(ValueError, match="The scale factors must be positive odd integers:"): + mitigated_function(0.1) + + def test_not_qnode_error(): """Test that when applied not on a QNode the transform raises an error.""" @catalyst.qjit def mitigated_function(args): - return catalyst.mitigate_with_zne(circuit, scale_factors=jax.numpy.array([1, 2, 3]))(args) + return catalyst.mitigate_with_zne(circuit, scale_factors=[1, 3, 5, 7])(args) with pytest.raises(TypeError, match="A QNode is expected, got the classical function"): mitigated_function(0.1) @@ -151,7 +176,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, scale_factors=[1, 3, 5, 7], extrapolate=extrapolation )(args) with pytest.raises( @@ -177,7 +202,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, scale_factors=[1, 3, 5, 7], extrapolate=extrapolation )(args) with pytest.raises( @@ -203,7 +228,7 @@ def circuit(x): @catalyst.qjit def mitigated_qnode(args): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + circuit, scale_factors=[1, 3, 5, 7], extrapolate=extrapolation )(args) with pytest.raises( @@ -238,7 +263,7 @@ def circuit(): return 0.0 def mitigated_qnode(): - return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="all")() + return catalyst.mitigate_with_zne(circuit, scale_factors=[], folding="local-random")() with pytest.raises(NotImplementedError): catalyst.qjit(mitigated_qnode) @@ -246,9 +271,10 @@ def mitigated_qnode(): @pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5]) @pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate]) -def test_zne_usage_patterns(params, extrapolation): +@pytest.mark.parametrize("folding", ["global", "local-all"]) +def test_zne_usage_patterns(params, extrapolation, folding): """Test usage patterns of catalyst.zne.""" -
skip_if_exponential_extrapolation_unstable(params, extrapolation) + skip_if_exponential_extrapolation_unstable(params, extrapolation, threshold=0.2) dev = qml.device("lightning.qubit", wires=2) @@ -264,13 +290,13 @@ def fn(x): @catalyst.qjit def mitigated_qnode_fn_as_argument(args): return catalyst.mitigate_with_zne( - fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + fn, scale_factors=[1, 3, 5, 7], extrapolate=extrapolation, folding=folding )(args) @catalyst.qjit def mitigated_qnode_partial(args): return catalyst.mitigate_with_zne( - scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation + scale_factors=[1, 3, 5, 7], extrapolate=extrapolation, folding=folding )(fn)(args) assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params)) @@ -296,7 +322,7 @@ def jax_extrapolation(scale_factors, results): @catalyst.qjit def mitigated_qnode(): return catalyst.mitigate_with_zne( - circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation + circuit, scale_factors=[1, 3, 5, 7], extrapolate=jax_extrapolation )() assert np.allclose(mitigated_qnode(), circuit()) @@ -319,7 +345,7 @@ def circuit(): def mitigated_qnode(): return catalyst.mitigate_with_zne( circuit, - scale_factors=jax.numpy.array([1, 2, 3]), + scale_factors=[1, 3, 5, 7], extrapolate=qml.transforms.poly_extrapolate, extrapolate_kwargs={"order": 2}, )() @@ -344,7 +370,7 @@ def circuit(): def mitigated_qnode(): return catalyst.mitigate_with_zne( circuit, - scale_factors=jax.numpy.array([1, 2, 3]), + scale_factors=[1, 3, 5, 7], extrapolate=qml.transforms.exponential_extrapolate, extrapolate_kwargs={"asymptote": 3}, )() diff --git a/mlir/include/Mitigation/IR/MitigationOps.td b/mlir/include/Mitigation/IR/MitigationOps.td index 8cb682d26b..11f574e398 100644 --- a/mlir/include/Mitigation/IR/MitigationOps.td +++ b/mlir/include/Mitigation/IR/MitigationOps.td @@ -27,8 +27,8 @@ def Folding : I32EnumAttr<"Folding", "Folding types", [ I32EnumAttrCase<"global", 1>, - I32EnumAttrCase<"random", 2>, - I32EnumAttrCase<"all", 3>, + I32EnumAttrCase<"all", 2>, + I32EnumAttrCase<"random", 3>, ]> { let cppNamespace = "catalyst::mitigation"; let genSpecializedAttr = 0; @@ -48,12 +48,12 @@ def ZneOp : Mitigation_Op<"zne", [DeclareOpInterfaceMethods<CallOpInterface>, FlatSymbolRefAttr:$callee, Variadic<AnyType>:$args, FoldingAttr:$folding, - RankedTensorOf<[AnySignlessIntegerOrIndex]>:$scaleFactors + RankedTensorOf<[AnySignlessIntegerOrIndex]>:$numFolds ); let results = (outs Variadic<RankedTensorOf<[AnyFloat]>>); let assemblyFormat = [{ - $callee `(` $args `)` `folding` `(` $folding `)` `scaleFactors` `(` $scaleFactors `:` type($scaleFactors) `)` attr-dict `:` functional-type($args, results) + $callee `(` $args `)` `folding` `(` $folding `)` `numFolds` `(` $numFolds `:` type($numFolds) `)` attr-dict `:` functional-type($args, results) }]; } diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp index 8021c5f731..588d6493e9 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.cpp @@ -40,20 +40,23 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - // Scalar factors - auto scaleFactors = op.getScaleFactors(); - RankedTensorType scaleFactorType = cast<RankedTensorType>(scaleFactors.getType()); - const auto sizeInt = scaleFactorType.getDimSize(0); + // Number of folds + auto numFolds = op.getNumFolds(); + RankedTensorType numFoldType = cast<RankedTensorType>(numFolds.getType());
+ const auto sizeInt = numFoldType.getDimSize(0); + + // Folding type + auto foldingAlgorithm = op.getFolding(); // Create the folded circuit function FlatSymbolRefAttr foldedCircuitRefAttr = - getOrInsertFoldedCircuit(loc, rewriter, op, scaleFactorType.getElementType()); + getOrInsertFoldedCircuit(loc, rewriter, op, foldingAlgorithm); func::FuncOp foldedCircuit = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, foldedCircuitRefAttr); RankedTensorType resultType = cast<RankedTensorType>(op.getResultTypes().front()); - // Loop over the scalars to create a folded circuit per factor + // Loop over the number of folds to create a folded circuit per fold count Value c0 = rewriter.create<index::ConstantOp>(loc, 0); Value c1 = rewriter.create<index::ConstantOp>(loc, 1); Value size = rewriter.create<index::ConstantOp>(loc, sizeInt); @@ -67,11 +70,10 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { std::vector<Value> newArgs(op.getArgs().begin(), op.getArgs().end()); SmallVector<Value> index = {i}; - Value scalarFactor = - builder.create<tensor::ExtractOp>(loc, scaleFactors, index); - Value scalarFactorCasted = - builder.create<index::CastSOp>(loc, builder.getIndexType(), scalarFactor); - newArgs.push_back(scalarFactorCasted); + Value numFold = builder.create<tensor::ExtractOp>(loc, numFolds, index); + Value numFoldCasted = + builder.create<index::CastSOp>(loc, builder.getIndexType(), numFold); + newArgs.push_back(numFoldCasted); func::CallOp callOp = builder.create<func::CallOp>(loc, foldedCircuit, newArgs); int64_t numResults = callOp.getNumResults(); @@ -123,78 +125,28 @@ void ZneLowering::rewrite(mitigation::ZneOp op, PatternRewriter &rewriter) const // Replace the original results rewriter.replaceOp(op, resultValues); } - -FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, - mitigation::ZneOp op, Type scalarType) +// In *.cpp module only, to keep extraneous headers out of *.hpp +FlatSymbolRefAttr globalFolding(Location loc, PatternRewriter &rewriter, std::string fnFoldedName, + StringAttr lib, StringAttr name, StringAttr kwargs, + int64_t numberQubits, FunctionType fnFoldedType, + SmallVector<Type> typesFolded, func::FuncOp fnFoldedOp, + func::FuncOp fnAllocOp, func::FuncOp fnWithoutMeasurementsOp, + func::FuncOp fnWithMeasurementsOp) { - MLIRContext *ctx = rewriter.getContext(); - - OpBuilder::InsertionGuard guard(rewriter); - ModuleOp moduleOp = op->getParentOfType<ModuleOp>(); - std::string fnFoldedName = op.getCallee().str() + ".folded"; - - if (moduleOp.lookupSymbol<func::FuncOp>(fnFoldedName)) { - return SymbolRefAttr::get(ctx, fnFoldedName); - } - - // Original function - func::FuncOp fnOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, op.getCalleeAttr()); - TypeRange originalTypes = op.getArgs().getTypes(); - Type qregType = quantum::QuregType::get(rewriter.getContext()); - - // Set insertion in the module - rewriter.setInsertionPointToStart(moduleOp.getBody()); - // Quantum Alloc function - FlatSymbolRefAttr quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); - func::FuncOp fnAllocOp = - SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, quantumAllocRefAttr); - - // Get the number of qubits - quantum::AllocOp allocOp = *fnOp.getOps<quantum::AllocOp>().begin(); - std::optional<int64_t> numberQubitsOptional = allocOp.getNqubitsAttr(); - int64_t numberQubits = numberQubitsOptional.value_or(0); - // Get the device - quantum::DeviceInitOp deviceInitOp = *fnOp.getOps<quantum::DeviceInitOp>().begin(); - StringAttr lib = deviceInitOp.getLibAttr(); - StringAttr name = deviceInitOp.getNameAttr(); - StringAttr kwargs = deviceInitOp.getKwargsAttr(); - - // Function without measurements:
Create function without measurements and with qreg as last - // argument - FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = - getOrInsertFnWithoutMeasurements(loc, rewriter, op); - func::FuncOp fnWithoutMeasurementsOp = - SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, fnWithoutMeasurementsRefAttr); - - // Function with measurements: Modify the original function to take a quantum register as last - // arg and keep measurements - FlatSymbolRefAttr fnWithMeasurementsRefAttr = getOrInsertFnWithMeasurements(loc, rewriter, op); - func::FuncOp fnWithMeasurementsOp = - SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, fnWithMeasurementsRefAttr); - // Function folded: Create the folded circuit (withoutMeasurement * - // Adjoint(withoutMeasurement))**scalar_factor * withMeasurements - rewriter.setInsertionPointToStart(moduleOp.getBody()); - SmallVector<Type> typesFolded(originalTypes.begin(), originalTypes.end()); - Type indexType = rewriter.getIndexType(); - typesFolded.push_back(indexType); - FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ - typesFolded, - /*outputs=*/fnOp.getResultTypes()); - - func::FuncOp fnFoldedOp = rewriter.create<func::FuncOp>(loc, fnFoldedName, fnFoldedType); - fnFoldedOp.setPrivate(); + // Adjoint(withoutMeasurement))**num_fold * withMeasurements + Type qregType = quantum::QuregType::get(rewriter.getContext()); - Block *foldedBloc = fnFoldedOp.addEntryBlock(); - rewriter.setInsertionPointToStart(foldedBloc); - // Add device - rewriter.create<quantum::DeviceInitOp>(loc, lib, name, kwargs); + rewriter.setInsertionPointToStart(fnFoldedOp.addEntryBlock()); + // Loop control variables + Value c0 = rewriter.create<index::ConstantOp>(loc, 0); + Value c1 = rewriter.create<index::ConstantOp>(loc, 1); TypedAttr numberQubitsAttr = rewriter.getI64IntegerAttr(numberQubits); Value numberQubitsValue = rewriter.create<arith::ConstantOp>(loc, numberQubitsAttr); + rewriter.create<quantum::DeviceInitOp>(loc, lib, name, kwargs); + Value allocQreg = rewriter.create<func::CallOp>(loc, fnAllocOp, numberQubitsValue).getResult(0); - Value c0 = rewriter.create<index::ConstantOp>(loc, 0); - Value c1 = rewriter.create<index::ConstantOp>(loc, 1); int64_t sizeArgs = fnFoldedOp.getArguments().size(); Value size = fnFoldedOp.getArgument(sizeArgs - 1); // Add scf for loop to create the folding @@ -243,7 +195,146 @@ FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRew // Remove device rewriter.create<quantum::DeviceReleaseOp>(loc); rewriter.create<func::ReturnOp>(loc, funcFolded); - return SymbolRefAttr::get(ctx, fnFoldedName); + return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); +} +// In *.cpp module only, to keep extraneous headers out of *.hpp +FlatSymbolRefAttr randomLocalFolding(PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnFoldedOp, Value c0, Value c1) +{ + // TODO: Implement. + + // Can't throw, because exceptions are disabled by compilation.
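+ // The folding="local-random" option is rejected with a NotImplementedError in the frontend before this lowering runs (see test_mitigation.py above), so the null attribute returned below is never consumed.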
+ // throw std::logic_error("Random local folding not implemented!"); + + return FlatSymbolRefAttr(); +} +// In *.cpp module only, to keep extraneous headers out of *.hpp +FlatSymbolRefAttr allLocalFolding(PatternRewriter &rewriter, std::string fnFoldedName, + func::FuncOp fnFoldedOp, Value c0, Value c1) +{ + + int64_t sizeArgs = fnFoldedOp.getArguments().size(); + Value size = fnFoldedOp.getArgument(sizeArgs - 1); + + // Walk through the operations in fnFoldedOp + fnFoldedOp.walk([&](quantum::QuantumGate op) { + rewriter.setInsertionPoint(op); + auto loc = op->getLoc(); + const std::vector<Value> opQubitArgs = op.getQubitOperands(); + + // Insert a for loop immediately before each quantum::QuantumGate + const auto forVal = + rewriter + .create<scf::ForOp>( + loc, c0, size, c1, /*iterArgsInit=*/opQubitArgs, + [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { + // Create adjoint and original operations + quantum::QuantumGate origOp = + dyn_cast<quantum::QuantumGate>(builder.clone(*op)); + origOp.setQubitOperands(iterArgs); + auto origOpVal = origOp->getResults(); + + quantum::QuantumGate adjointOp = + dyn_cast<quantum::QuantumGate>(builder.clone(*origOp)); + adjointOp.setQubitOperands(origOpVal); + adjointOp.setAdjointFlag(!adjointOp.getAdjointFlag()); + auto adjointOpVal = adjointOp->getResults(); + + // Yield the qubits. + builder.create<scf::YieldOp>(loc, adjointOpVal); + }) + .getResults(); + + op.setQubitOperands(forVal); + + return WalkResult::advance(); + }); + + // Return the function symbol reference + return SymbolRefAttr::get(rewriter.getContext(), fnFoldedName); +} +FlatSymbolRefAttr ZneLowering::getOrInsertFoldedCircuit(Location loc, PatternRewriter &rewriter, + mitigation::ZneOp op, + Folding foldingAlgorithm) +{ + OpBuilder::InsertionGuard guard(rewriter); + ModuleOp moduleOp = op->getParentOfType<ModuleOp>(); + std::string fnFoldedName = op.getCallee().str() + ".folded"; + MLIRContext *ctx = rewriter.getContext(); + + if (moduleOp.lookupSymbol<func::FuncOp>(fnFoldedName)) { + return SymbolRefAttr::get(ctx, fnFoldedName); + } + + // Original function + func::FuncOp fnOp = SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, op.getCalleeAttr()); + + // Set insertion in the module + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + // Get the number of qubits + const int64_t numberQubits = + (*fnOp.getOps<quantum::AllocOp>().begin()).getNqubitsAttr().value_or(0); + // Get the device + quantum::DeviceInitOp deviceInitOp = *fnOp.getOps<quantum::DeviceInitOp>().begin(); + + StringAttr lib = deviceInitOp.getLibAttr(); + StringAttr name = deviceInitOp.getNameAttr(); + StringAttr kwargs = deviceInitOp.getKwargsAttr(); + + TypeRange originalTypes = op.getArgs().getTypes(); + SmallVector<Type> typesFolded(originalTypes.begin(), originalTypes.end()); + Type indexType = rewriter.getIndexType(); + typesFolded.push_back(indexType); + + rewriter.setInsertionPointToStart(moduleOp.getBody()); + + FunctionType fnFoldedType = FunctionType::get(ctx, /*inputs=*/ + typesFolded, + /*outputs=*/fnOp.getResultTypes()); + + func::FuncOp fnFoldedOp = rewriter.create<func::FuncOp>(loc, fnFoldedName, fnFoldedType); + fnFoldedOp.setPrivate(); + if (foldingAlgorithm == Folding(1)) { + // Quantum Alloc function + FlatSymbolRefAttr quantumAllocRefAttr = getOrInsertQuantumAlloc(loc, rewriter, op); + func::FuncOp fnAllocOp = + SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, quantumAllocRefAttr); + + // Function without measurements: Create function without measurements and with qreg as last + // argument + FlatSymbolRefAttr fnWithoutMeasurementsRefAttr = + getOrInsertFnWithoutMeasurements(loc, rewriter, op); + func::FuncOp fnWithoutMeasurementsOp = +
SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, fnWithoutMeasurementsRefAttr); + + // Function with measurements: Modify the original function to take a quantum register as + // last arg and keep measurements + FlatSymbolRefAttr fnWithMeasurementsRefAttr = + getOrInsertFnWithMeasurements(loc, rewriter, op); + func::FuncOp fnWithMeasurementsOp = + SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, fnWithMeasurementsRefAttr); + + return globalFolding(loc, rewriter, fnFoldedName, lib, name, kwargs, numberQubits, + fnFoldedType, typesFolded, fnFoldedOp, fnAllocOp, + fnWithoutMeasurementsOp, fnWithMeasurementsOp); + } + + rewriter.cloneRegionBefore(fnOp.getBody(), fnFoldedOp.getBody(), fnFoldedOp.end()); + + Block *fnFoldedOpBlock = &fnFoldedOp.getBody().front(); + rewriter.setInsertionPointToStart(fnFoldedOpBlock); + // Loop control variables + Value c0 = rewriter.create<index::ConstantOp>(loc, 0); + Value c1 = rewriter.create<index::ConstantOp>(loc, 1); + + fnFoldedOpBlock->addArgument(fnFoldedOp.getArgumentTypes().back(), loc); + + if (foldingAlgorithm == Folding(2)) { + return allLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); + } + // Else, if (foldingAlgorithm == Folding(3)): + return randomLocalFolding(rewriter, fnFoldedName, fnFoldedOp, c0, c1); } FlatSymbolRefAttr ZneLowering::getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op) diff --git a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp index 2f9388c36f..32d871e974 100644 --- a/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp +++ b/mlir/lib/Mitigation/Transforms/MitigationMethods/Zne.hpp @@ -32,7 +32,8 @@ struct ZneLowering : public OpRewritePattern<mitigation::ZneOp> { private: static FlatSymbolRefAttr getOrInsertFoldedCircuit(Location loc, PatternRewriter &builder, - mitigation::ZneOp op, Type scalarType); + mitigation::ZneOp op, + Folding foldingAlgorithm); static FlatSymbolRefAttr getOrInsertQuantumAlloc(Location loc, PatternRewriter &rewriter, mitigation::ZneOp op); static FlatSymbolRefAttr diff --git a/mlir/test/Mitigation/ZneFoldingAllFullTest.mlir b/mlir/test/Mitigation/ZneFoldingAllFullTest.mlir new file mode 100644 index 0000000000..2adfd6b4d9 --- /dev/null +++ b/mlir/test/Mitigation/ZneFoldingAllFullTest.mlir @@ -0,0 +1,105 @@ +// Copyright 2023 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
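+ +// With folding "all", the lowering wraps each gate G in an scf.for loop that applies the pair (G, adjoint(G)) numFolds times and then applies G once more after the loop, so a fold count of n lowers each gate to (G G†)^n G; the CHECK lines below verify this per-gate structure.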
+ + // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s + + // CHECK-LABEL: func.func private @circuit.folded(%arg0: index) -> tensor<f64> { + // CHECK: [[c0:%.+]] = index.constant 0 + // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = quantum.alloc( 4) : !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) { + // CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit + // CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q0_loop2]] : !quantum.bit + // CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q2:%.+]] = quantum.extract [[qReg]][ 2] : !quantum.reg -> !quantum.bit + // CHECK: [[q12_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q12_in1:%.+]] = [[q01_out2]]#1, [[q12_in2:%.+]] = [[q2]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q12_loop:%.+]]:2 = quantum.custom "CNOT"() [[q12_in1]], [[q12_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q12_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q12_loop]]#0, [[q12_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q12_loop2]]#0, [[q12_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q12_out2:%.+]]:2 = quantum.custom "CNOT"() [[q12_out]]#0, [[q12_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q1_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q1_in:%.+]] = [[q12_out2]]#0) -> (!quantum.bit) { + // CHECK: [[q1_loop:%.+]] = quantum.custom "T"() [[q1_in]] : !quantum.bit + // CHECK: [[q1_loop2:%.+]] = quantum.custom "T"() [[q1_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q1_loop2]] : !quantum.bit + // CHECK: [[q1_out2:%.+]] = quantum.custom "T"() [[q1_out]] : !quantum.bit + // CHECK: [[q3:%.+]] = quantum.extract [[qReg]][ 3] : !quantum.reg -> !quantum.bit + // CHECK: [[q23_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q23_in1:%.+]] = [[q12_out2]]#1, [[q23_in2:%.+]] = [[q3]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q23_loop:%.+]]:2 = quantum.custom "CNOT"() [[q23_in1]], [[q23_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q23_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q23_loop]]#0, [[q23_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q23_loop2]]#0, [[q23_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q23_out2:%.+]]:2 = quantum.custom "CNOT"() [[q23_out]]#0, [[q23_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q3_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]]
iter_args([[q3_in:%.+]] = [[q23_out2]]#1) -> (!quantum.bit) { + // CHECK: [[q3_loop:%.+]] = quantum.custom "T"() [[q3_in]] {adjoint} : !quantum.bit + // CHECK: [[q3_loop2:%.+]] = quantum.custom "T"() [[q3_loop]] : !quantum.bit + // CHECK: scf.yield [[q3_loop2]] : !quantum.bit + // CHECK: [[q3_out2:%.+]] = quantum.custom "T"() [[q3_out]] {adjoint} : !quantum.bit + + +//CHECK-LABEL: func.func @circuit() -> tensor<f64> attributes {qnode} { +func.func @circuit() -> tensor<f64> attributes {qnode} { + quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + %0 = quantum.alloc( 4) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits, %2 : !quantum.bit, !quantum.bit + %3 = quantum.extract %0[ 2] : !quantum.reg -> !quantum.bit + %out_qubits_1:2 = quantum.custom "CNOT"() %out_qubits_0#1, %3 : !quantum.bit, !quantum.bit + %out_qubits_2 = quantum.custom "T"() %out_qubits_1#0 : !quantum.bit + %4 = quantum.extract %0[ 3] : !quantum.reg -> !quantum.bit + %out_qubits_3:2 = quantum.custom "CNOT"() %out_qubits_1#1, %4 : !quantum.bit, !quantum.bit + %out_qubits_4 = quantum.custom "T"() %out_qubits_3#1 {adjoint} : !quantum.bit + %5 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs + %6 = quantum.expval %5 {shots = 5 : i64} : f64 + %from_elements = tensor.from_elements %6 : tensor<f64> + %7 = quantum.insert %0[ 0], %out_qubits_0#0 : !quantum.reg, !quantum.bit + %8 = quantum.insert %7[ 1], %out_qubits_2 : !quantum.reg, !quantum.bit + %9 = quantum.insert %8[ 2], %out_qubits_3#0 : !quantum.reg, !quantum.bit + %10 = quantum.insert %9[ 3], %out_qubits_4 : !quantum.reg, !quantum.bit + quantum.dealloc %10 : !quantum.reg + quantum.device_release + return %from_elements : tensor<f64> + } + +//CHECK-LABEL: func.func @mitigated_circuit() + //CHECK-DAG: [[c0:%.+]] = index.constant 0 + //CHECK-DAG: [[c1:%.+]] = index.constant 1 + //CHECK-DAG: [[c3:%.+]] = index.constant 3 + //CHECK-DAG: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> + //CHECK-DAG: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> + // CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args([[emptyArg:%.+]] = [[emptyRes]]) -> (tensor<3xf64>) { + // CHECK: [[scalarFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + // CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scalarFactor]]) : (index) -> tensor<f64> + // CHECK: [[extracted:%.+]] = tensor.extract [[intermediateRes]][] : tensor<f64> + // CHECK: [[from_elements:%.+]] = tensor.from_elements [[extracted]] : tensor<1xf64> + // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args([[scalarArg:%.+]] = [[emptyArg]]) -> (tensor<3xf64>) { + // CHECK: [[extracted:%.+]] = tensor.extract %from_elements[%arg2] : tensor<1xf64> + // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg0] : tensor<3xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] +func.func @mitigated_circuit() -> tensor<3xf64> { + %numFolds = arith.constant dense<[1, 2, 3]> : tensor<3xindex> + %0 = mitigation.zne @circuit() folding (all) numFolds (%numFolds : tensor<3xindex>) : () -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} diff --git a/mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir b/mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir new file mode 100644 index
0000000000..61773fc842 --- /dev/null +++ b/mlir/test/Mitigation/ZneFoldingAllMinimalTest.mlir @@ -0,0 +1,82 @@ +// Copyright 2023 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func private @circuit.folded(%arg0: index) -> tensor<f64> { + // CHECK: [[c0:%.+]] = index.constant 0 + // CHECK: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = quantum.alloc( 2) : !quantum.reg + // CHECK: [[q0:%.+]] = quantum.extract [[qReg]][ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q0_out:%.+]] = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q0_in:%.+]] = [[q0]]) -> (!quantum.bit) { + // CHECK: [[q0_loop:%.+]] = quantum.custom "Hadamard"() [[q0_in]] : !quantum.bit + // CHECK: [[q0_loop2:%.+]] = quantum.custom "Hadamard"() [[q0_loop]] {adjoint} : !quantum.bit + // CHECK: scf.yield [[q0_loop2]] : !quantum.bit + // CHECK: [[q0_out2:%.+]] = quantum.custom "Hadamard"() [[q0_out]] : !quantum.bit + // CHECK: [[q1:%.+]] = quantum.extract [[qReg]][ 1] : !quantum.reg -> !quantum.bit + // CHECK: [[q01_out:%.+]]:2 = scf.for %arg1 = [[c0]] to %arg0 step [[c1]] iter_args([[q01_in1:%.+]] = [[q0_out2]], [[q01_in2:%.+]] = [[q1]]) -> (!quantum.bit, !quantum.bit) { + // CHECK: [[q01_loop:%.+]]:2 = quantum.custom "CNOT"() [[q01_in1]], [[q01_in2]] : !quantum.bit, !quantum.bit + // CHECK: [[q01_loop2:%.+]]:2 = quantum.custom "CNOT"() [[q01_loop]]#0, [[q01_loop]]#1 {adjoint} : !quantum.bit, !quantum.bit + // CHECK: scf.yield [[q01_loop2]]#0, [[q01_loop2]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q01_out2:%.+]]:2 = quantum.custom "CNOT"() [[q01_out]]#0, [[q01_out]]#1 : !quantum.bit, !quantum.bit + // CHECK: [[q2:%.+]] = quantum.namedobs [[q01_out2]]#0[ PauliY] : !quantum.obs + // CHECK: [[result:%.+]] = quantum.expval [[q2]] : f64 + // CHECK: [[tensorRes:%.+]] = tensor.from_elements [[result]] : tensor<f64> + // CHECK: [[q2:%.+]] = quantum.insert %0[ 0], [[q01_out2]]#0 : !quantum.reg, !quantum.bit + // CHECK: [[q3:%.+]] = quantum.insert %7[ 1], [[q01_out2]]#1 : !quantum.reg, !quantum.bit + // CHECK: quantum.dealloc [[q3]] : !quantum.reg + // CHECK: quantum.device_release + // CHECK: return [[tensorRes]] + +//CHECK-LABEL: func.func @circuit() -> tensor<f64> attributes {qnode} { +func.func @circuit() -> tensor<f64> attributes {qnode} { + quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + %0 = quantum.alloc( 2) : !quantum.reg + %1 = quantum.extract %0[ 0] : !quantum.reg -> !quantum.bit + %out_qubits = quantum.custom "Hadamard"() %1 : !quantum.bit + %2 = quantum.extract %0[ 1] : !quantum.reg -> !quantum.bit + %out_qubits_0:2 = quantum.custom "CNOT"() %out_qubits, %2 : !quantum.bit, !quantum.bit + %3 = quantum.namedobs %out_qubits_0#0[ PauliY] : !quantum.obs + %4 = quantum.expval %3 : f64 + %from_elements = tensor.from_elements %4 : tensor<f64> + %5 = quantum.insert %0[ 0],
%out_qubits_0#0 : !quantum.reg, !quantum.bit + %6 = quantum.insert %5[ 1], %out_qubits_0#1 : !quantum.reg, !quantum.bit + quantum.dealloc %6 : !quantum.reg + quantum.device_release + return %from_elements : tensor<f64> +} + +//CHECK-LABEL: func.func @mitigated_circuit() + //CHECK-DAG: [[c0:%.+]] = index.constant 0 + //CHECK-DAG: [[c1:%.+]] = index.constant 1 + //CHECK-DAG: [[c3:%.+]] = index.constant 3 + //CHECK-DAG: [[emptyRes:%.+]] = tensor.empty() : tensor<3xf64> + //CHECK-DAG: [[dense3:%.+]] = arith.constant dense<[1, 2, 3]> + // CHECK: [[results:%.+]] = scf.for [[idx:%.+]] = [[c0]] to [[c3]] step [[c1]] iter_args([[emptyArg:%.+]] = [[emptyRes]]) -> (tensor<3xf64>) { + // CHECK: [[scalarFactor:%.+]] = tensor.extract [[dense3]][[[idx]]] : tensor<3xindex> + // CHECK: [[intermediateRes:%.+]] = func.call @circuit.folded([[scalarFactor]]) : (index) -> tensor<f64> + // CHECK: [[extracted:%.+]] = tensor.extract [[intermediateRes]][] : tensor<f64> + // CHECK: [[from_elements:%.+]] = tensor.from_elements [[extracted]] : tensor<1xf64> + // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args([[scalarArg:%.+]] = [[emptyArg]]) -> (tensor<3xf64>) { + // CHECK: [[extracted:%.+]] = tensor.extract %from_elements[%arg2] : tensor<1xf64> + // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg3[%arg0] : tensor<3xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] +func.func @mitigated_circuit() -> tensor<3xf64> { + %numFolds = arith.constant dense<[1, 2, 3]> : tensor<3xindex> + %0 = mitigation.zne @circuit() folding (all) numFolds (%numFolds : tensor<3xindex>) : () -> tensor<3xf64> + func.return %0 : tensor<3xf64> +} diff --git a/mlir/test/Mitigation/zne.mlir b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir similarity index 63% rename from mlir/test/Mitigation/zne.mlir rename to mlir/test/Mitigation/ZneFoldingGlobalTest.mlir index fe399a9bf1..256fbee72f 100644 --- a/mlir/test/Mitigation/zne.mlir +++ b/mlir/test/Mitigation/ZneFoldingGlobalTest.mlir @@ -14,6 +14,47 @@ // RUN: quantum-opt %s --lower-mitigation --split-input-file --verify-diagnostics | FileCheck %s +// CHECK-LABEL: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { + // CHECK-DAG: [[nQubits:%.+]] = arith.constant 1 + // CHECK-DAG: [[c0:%.+]] = index.constant 0 + // CHECK-DAG: [[c1:%.+]] = index.constant 1 + // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] + // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg + // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = [[c0]] to %arg1 step [[c1]] iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { + // CHECK: [[outQreg1:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, [[inQreg]]) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg + // CHECK: [[outQreg2:%.+]] = quantum.adjoint([[outQreg1]]) : !quantum.reg { + // CHECK: ^bb0(%arg4: !quantum.reg): + // CHECK: [[callWithoutMeasurements:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, %arg4) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg + // CHECK: quantum.yield [[callWithoutMeasurements]] : !quantum.reg + // CHECK: scf.yield [[outQreg2]] : !quantum.reg + // CHECK: [[results:%.+]] = call @simpleCircuit.withMeasurements(%arg0, [[outQregFor]]) : (tensor<3xf64>, !quantum.reg) -> f64 + // CHECK: quantum.device_release + // CHECK: return [[results]] + +// CHECK-LABEL: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg
{ + // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg + // CHECK: return [[allocQreg]] : !quantum.reg + +// CHECK-LABEL: func.func private @simpleCircuit.withoutMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> !quantum.reg { + // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit + // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit + // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit + // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit + // CHECK: return [[q_4]] : !quantum.reg + +// CHECK-LABEL: func.func private @simpleCircuit.withMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> f64 { + // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit + // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit + // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit + // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit + // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit + // CHECK: [[q_5:%.+]] = quantum.namedobs [[q_3]][ PauliX] : !quantum.obs + // CHECK: [[results:%.+]] = quantum.expval [[q_5]] : f64 + // CHECK: quantum.dealloc [[q_4]] : !quantum.reg + // CHECK: return [[results]] : f64 + +// CHECK-LABEL: func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { quantum.device ["rtd_lightning.so", "LightningQubit", "{shots: 0}"] %c0 = arith.constant 0 : index @@ -39,48 +80,6 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { func.return %expval : f64 } -// CHECK: func.func private @simpleCircuit.folded(%arg0: tensor<3xf64>, %arg1: index) -> f64 { - // CHECK-DAG: [[nQubits:%.+]] = arith.constant 1 : i64 - // CHECK-DAG: %idx0 = index.constant 0 - // CHECK-DAG: %idx1 = index.constant 1 - // CHECK: quantum.device["rtd_lightning.so", "LightningQubit", "{shots: 0}"] - // CHECK: [[qReg:%.+]] = call @simpleCircuit.quantumAlloc([[nQubits]]) : (i64) -> !quantum.reg - // CHECK: [[outQregFor:%.+]] = scf.for %arg2 = %idx0 to %arg1 step %idx1 iter_args([[inQreg:%.+]] = [[qReg]]) -> (!quantum.reg) { - // CHECK: [[outQreg1:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, [[inQreg]]) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: [[outQreg2:%.+]] = quantum.adjoint([[outQreg1]]) : !quantum.reg { - // CHECK: ^bb0(%arg4: !quantum.reg): - // CHECK: [[callWithoutMeasurements:%.+]] = func.call @simpleCircuit.withoutMeasurements(%arg0, %arg4) : (tensor<3xf64>, !quantum.reg) -> !quantum.reg - // CHECK: quantum.yield [[callWithoutMeasurements]] : !quantum.reg - // CHECK: scf.yield [[outQreg2]] : !quantum.reg - // CHECK: [[results:%.+]] = call @simpleCircuit.withMeasurements(%arg0, [[outQregFor]]) : (tensor<3xf64>, !quantum.reg) -> f64 - // CHECK: quantum.device_release - // CHECK: return [[results]] : f64 - -// CHECK: func.func private @simpleCircuit.quantumAlloc(%arg0: i64) -> !quantum.reg { - // CHECK: [[allocQreg:%.+]] = quantum.alloc(%arg0) : !quantum.reg - // CHECK: return [[allocQreg]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withoutMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> !quantum.reg { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: 
[[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: return [[q_4]] : !quantum.reg - -// CHECK: func.func private @simpleCircuit.withMeasurements(%arg0: tensor<3xf64>, %arg1: !quantum.reg) -> f64 { - // CHECK: [[q_0:%.+]] = quantum.extract %arg1[ 0] : !quantum.reg -> !quantum.bit - // CHECK: [[q_1:%.+]] = quantum.custom "h"() [[q_0]] : !quantum.bit - // CHECK: [[q_2:%.+]] = quantum.custom "rz"({{.*}}) [[q_1]] : !quantum.bit - // CHECK: [[q_3:%.+]] = quantum.custom "u3"({{.*}}, {{.*}}, {{.*}}) [[q_2]] : !quantum.bit - // CHECK: [[q_4:%.+]] = quantum.insert %arg1[ 0], [[q_3]] : !quantum.reg, !quantum.bit - // CHECK: [[q_5:%.+]] = quantum.namedobs [[q_3]][ PauliX] : !quantum.obs - // CHECK: [[resulst:%.+]] = quantum.expval [[q_5]] : f64 - // CHECK: quantum.dealloc [[q_4]] : !quantum.reg - // CHECK: return [[resulst]] : f64 - -// CHECK: func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { - // CHECK: func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { // CHECK-DAG: [[c0:%.+]] = index.constant 0 // CHECK-DAG: [[c1:%.+]] = index.constant 1 @@ -94,11 +93,11 @@ func.func @simpleCircuit(%arg0: tensor<3xf64>) -> f64 attributes {qnode} { // CHECK: [[resultsFor:%.+]] = scf.for [[idxJ:%.+]] = [[c0]] to [[c1]] step [[c1]] iter_args(%arg4 = %arg2) -> (tensor<5xf64>) { // CHECK: [[extracted:%.+]] = tensor.extract [[tensorRes]][%arg3] : tensor<1xf64> // CHECK: [[insertedRes:%.+]] = tensor.insert [[extracted]] into %arg4[%arg1] : tensor<5xf64> - // CHECK: scf.yield [[insertedRes]] : tensor<5xf64> - // CHECK: scf.yield [[resultsFor]] : tensor<5xf64> - // CHECK: return [[results]] : tensor<5xf64> + // CHECK: scf.yield [[insertedRes]] + // CHECK: scf.yield [[resultsFor]] + // CHECK: return [[results]] func.func @zneCallScalarScalar(%arg0: tensor<3xf64>) -> tensor<5xf64> { - %scaleFactors = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> - %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) scaleFactors (%scaleFactors : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> + %numFolds = arith.constant dense<[1, 2, 3, 4, 5]> : tensor<5xindex> + %0 = mitigation.zne @simpleCircuit(%arg0) folding (global) numFolds (%numFolds : tensor<5xindex>) : (tensor<3xf64>) -> tensor<5xf64> func.return %0 : tensor<5xf64> } diff --git a/setup.py b/setup.py index 4cc52dc4bf..96619f6c18 100644 --- a/setup.py +++ b/setup.py @@ -177,7 +177,12 @@ def run(self): if system_platform == "Linux": custom_calls_extension = Extension( "catalyst.utils.libcustom_calls", - sources=["frontend/catalyst/utils/libcustom_calls.cpp"], + sources=[ + "frontend/catalyst/utils/libcustom_calls.cpp", + "frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp", + "frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp", + ], + extra_compile_args=["-std=c++17"], ) cmdclass = {"build_ext": CustomBuildExtLinux} @@ -189,7 +194,12 @@ def run(self): variables["LDCXXSHARED"] = variables["LDCXXSHARED"].replace("-bundle", "-dynamiclib") custom_calls_extension = Extension( "catalyst.utils.libcustom_calls", - sources=["frontend/catalyst/utils/libcustom_calls.cpp"], + sources=[ + "frontend/catalyst/utils/libcustom_calls.cpp", + "frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp", + 
"frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels_using_lapack.cpp", + ], + extra_compile_args=["-std=c++17"], ) cmdclass = {"build_ext": CustomBuildExtMacos}