diff --git a/.github/workflows/check-pl-compat.yaml b/.github/workflows/check-pl-compat.yaml
index 37fbaf9588..bfb5102c51 100644
--- a/.github/workflows/check-pl-compat.yaml
+++ b/.github/workflows/check-pl-compat.yaml
@@ -39,7 +39,7 @@ jobs:
git checkout $(git tag | sort -V | tail -1)
- if: ${{ inputs.catalyst == 'release-candidate' }}
run: |
- git checkout v0.8.0-rc
+ git checkout v0.8.1-rc
- name: Install deps
run: |
diff --git a/doc/dev/jax_integration.rst b/doc/dev/jax_integration.rst
index cbbbd733b6..4d427977cb 100644
--- a/doc/dev/jax_integration.rst
+++ b/doc/dev/jax_integration.rst
@@ -88,7 +88,6 @@ that doesn't work with Catalyst includes:
- ``jax.numpy.polyfit``
- ``jax.numpy.fft``
-- ``jax.scipy.linalg``
- ``jax.numpy.ndarray.at[index]`` when ``index`` corresponds to all array
indices.
diff --git a/doc/releases/changelog-0.8.1.md b/doc/releases/changelog-0.8.1.md
index f8b7d5b94c..62ac8ac216 100644
--- a/doc/releases/changelog-0.8.1.md
+++ b/doc/releases/changelog-0.8.1.md
@@ -2,14 +2,91 @@
New features
+* The `catalyst.mitigate_with_zne` error mitigation compilation pass now supports
+ folding gates locally, in addition to the existing global folding method.
+ [(#1006)](https://github.com/PennyLaneAI/catalyst/pull/1006)
+ [(#1129)](https://github.com/PennyLaneAI/catalyst/pull/1129)
+
+ While global folding applies the scale factor by forming the inverse of the
+ entire quantum circuit (without measurements) and repeating
+ the circuit with its inverse, local folding instead inserts per-gate folding sequences directly in place
+ of each gate in the original circuit.
+
+ For example,
+
+ ```python
+ import jax
+ import pennylane as qml
+ from catalyst import qjit, mitigate_with_zne
+ from pennylane.transforms import exponential_extrapolate
+
+ dev = qml.device("lightning.qubit", wires=4, shots=5)
+
+ @qml.qnode(dev)
+ def circuit():
+ qml.Hadamard(wires=0)
+ qml.CNOT(wires=[0, 1])
+ return qml.expval(qml.PauliY(wires=0))
+
+ @qjit(keep_intermediate=True)
+ def mitigated_circuit():
+ s = [1, 3, 5]
+ return mitigate_with_zne(
+ circuit,
+ scale_factors=s,
+ extrapolate=exponential_extrapolate,
+ folding="local-all" # "local-all" for local on all gates or "global" for the original method (default being "global")
+ )()
+ ```
+
+ ```pycon
+ >>> circuit()
+ >>> mitigated_circuit()
+ ```
+
Improvements
+* Fixes an issue where certain JAX linear algebra functions from `jax.scipy.linalg` gave incorrect
+ results when invoked from within a qjit block, and adds full support for other `jax.scipy.linalg`
+ functions.
+ [(#1097)](https://github.com/PennyLaneAI/catalyst/pull/1097)
+
+ The supported linear algebra functions include, but are not limited to:
+
+ - [`jax.scipy.linalg.cholesky`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cholesky.html)
+ - [`jax.scipy.linalg.expm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.expm.html)
+ - [`jax.scipy.linalg.funm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.funm.html)
+ - [`jax.scipy.linalg.hessenberg`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.hessenberg.html)
+ - [`jax.scipy.linalg.lu`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu.html)
+ - [`jax.scipy.linalg.lu_solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.lu_solve.html)
+ - [`jax.scipy.linalg.polar`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.polar.html)
+ - [`jax.scipy.linalg.qr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.qr.html)
+ - [`jax.scipy.linalg.schur`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.schur.html)
+ - [`jax.scipy.linalg.solve`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.solve.html)
+ - [`jax.scipy.linalg.sqrtm`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.sqrtm.html)
+ - [`jax.scipy.linalg.svd`](https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.svd.html)
+
Breaking changes
-Deprecations
+* The `scale_factors` argument of the `mitigate_with_zne` function now follows the definition
+ used in the literature. It must be a list of positive odd integers, since fractional
+ folding is not supported.
+ [(#1120)](https://github.com/PennyLaneAI/catalyst/pull/1120)
Bug fixes
+* Functions that call the `gather_p` primitive (such as `jax.scipy.linalg.expm`)
+ can now be used in multiple qjit-compiled functions within a single program.
+ [(#1096)](https://github.com/PennyLaneAI/catalyst/pull/1096)
+
Contributors
This release contains contributions from (in alphabetical order):
+
+Joey Carter,
+Alessandro Cosentino,
+Paul Haochen Wang,
+David Ittah,
+Romain Moyard,
+Daniel Strano,
+Raul Torres.
diff --git a/frontend/catalyst/_version.py b/frontend/catalyst/_version.py
index 3a1257d76c..13fef47bd0 100644
--- a/frontend/catalyst/_version.py
+++ b/frontend/catalyst/_version.py
@@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""
-__version__ = "0.8.0"
+__version__ = "0.8.1"
diff --git a/frontend/catalyst/api_extensions/error_mitigation.py b/frontend/catalyst/api_extensions/error_mitigation.py
index 84c1982548..c3b67b11d7 100644
--- a/frontend/catalyst/api_extensions/error_mitigation.py
+++ b/frontend/catalyst/api_extensions/error_mitigation.py
@@ -30,6 +30,10 @@
from catalyst.jax_primitives import Folding, zne_p
+def _is_odd_positive(numbers_list):
+ return all(isinstance(i, int) and i > 0 and i % 2 != 0 for i in numbers_list)
+
+
## API ##
def mitigate_with_zne(
fn=None, *, scale_factors=None, extrapolate=None, extrapolate_kwargs=None, folding="global"
@@ -47,7 +51,7 @@ def mitigate_with_zne(
Args:
fn (qml.QNode): the circuit to be mitigated.
- scale_factors (array[int]): the range of noise scale factors used.
+ scale_factors (list[int]): the range of noise scale factors used.
extrapolate (Callable): A qjit-compatible function taking two sequences as arguments (scale
factors, and results), and returning a float by performing a fitting procedure.
By default, perfect polynomial fitting :func:`~.polynomial_extrapolate` will be used,
@@ -56,6 +60,7 @@ def mitigate_with_zne(
function.
folding (str): Unitary folding technique to be used to scale the circuit. Possible values:
- global: the global unitary of the input circuit is folded
+ - local-all: per-gate folding sequences replace original gates in-place in the circuit
Returns:
Callable: A callable object that computes the mitigated of the wrapped :class:`~.QNode`
@@ -113,10 +118,11 @@ def workflow(weights, s):
return zne_circuit(weights)
>>> weights = jnp.ones([3, 2, 3])
- >>> scale_factors = jnp.array([1, 2, 3])
+ >>> scale_factors = [1, 3, 5]
>>> workflow(weights, scale_factors)
Array(-0.19946598, dtype=float64)
"""
+
kwargs = copy.copy(locals())
kwargs.pop("fn")
@@ -128,7 +134,12 @@ def workflow(weights, s):
elif extrapolate_kwargs is not None:
extrapolate = functools.partial(extrapolate, **extrapolate_kwargs)
- return ZNE(fn, scale_factors, extrapolate, folding)
+ if not _is_odd_positive(scale_factors):
+     # f-string so that the offending value is actually interpolated into the message.
+     raise ValueError(f"The scale factors must be positive odd integers: {scale_factors}")
+
+ # A scale factor s = 2n + 1 corresponds to n folds, hence n = (s - 1) / 2.
+ num_folds = jnp.array([jnp.floor((s - 1) / 2) for s in scale_factors], dtype=int)
+
+ return ZNE(fn, num_folds, extrapolate, folding)
## IMPL ##
@@ -147,7 +158,7 @@ class ZNE:
def __init__(
self,
fn: Callable,
- scale_factors: jnp.ndarray,
+ num_folds: jnp.ndarray,
extrapolate: Callable[[Sequence[float], Sequence[float]], float],
folding: str,
):
@@ -155,7 +166,7 @@ def __init__(
raise TypeError(f"A QNode is expected, got the classical function {fn}")
self.fn = fn
self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
- self.scale_factors = scale_factors
+ self.num_folds = num_folds
self.extrapolate = extrapolate
self.folding = folding
@@ -175,14 +186,12 @@ def __call__(self, *args, **kwargs):
except ValueError as e:
raise ValueError(f"Folding type must be one of {list(map(str, Folding))}") from e
# TODO: remove the following check once #755 is completed
- if folding != Folding.GLOBAL:
+ if folding == Folding.RANDOM:
raise NotImplementedError(f"Folding type {folding.value} is being developed")
- results = zne_p.bind(
- *args_data, self.scale_factors, folding=folding, jaxpr=jaxpr, fn=self.fn
- )
- float_scale_factors = jnp.array(self.scale_factors, dtype=float)
- results = self.extrapolate(float_scale_factors, results[0])
+ results = zne_p.bind(*args_data, self.num_folds, folding=folding, jaxpr=jaxpr, fn=self.fn)
+ # Fit against the noise scale factors (s = 2n + 1 for n folds), not the raw fold counts:
+ # with fold counts the unfolded circuit would sit at x = 0 and no mitigation would occur.
+ float_scale_factors = 2.0 * jnp.array(self.num_folds, dtype=float) + 1.0
+ results = self.extrapolate(float_scale_factors, results[0])
# Single measurement
if results.shape == ():
return results
diff --git a/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py b/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py
deleted file mode 100644
index 5029e4d508..0000000000
--- a/frontend/catalyst/jax_extras/jax_scipy_linalg_warnings.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2022-2024 Xanadu Quantum Technologies Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-This module contains warnings for using jax.scipy.linalg functions inside qjit.
-Due to improperly linked lapack symbols, occasionally these functions give wrong
-numerical results when used in a qjit context.
-As for now, we warn users to wrap all of these with a catalyst.accelerate() callback.
-This patch should be removed when we have proper linkage to lapack.
-See:
- https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax
- https://github.com/PennyLaneAI/catalyst/issues/753
- https://github.com/PennyLaneAI/catalyst/issues/1071
-"""
-
-import warnings
-
-import jax
-
-from catalyst.tracing.contexts import AccelerateContext
-
-
-class JaxLinalgWarner:
- def __init__(self, fn):
- self.fn = fn
-
- def __call__(self, *args, **kwargs):
- if not AccelerateContext.am_inside_accelerate():
- warnings.warn(
- f"""
- jax.scipy.linalg.{self.fn.__name__} occasionally gives wrong numerical results
- when used within a qjit-compiled function.
- See https://github.com/PennyLaneAI/catalyst/issues/1071.
- In the meantime, we recommend catalyst.accelerate to call
- the underlying {self.fn.__name__} function directly:
-
- @qjit
- def f(A):
- return catalyst.accelerate(jax.scipy.linalg.{self.fn.__name__})(A)
-
- See https://docs.pennylane.ai/projects/catalyst/en/latest/code/api/catalyst.accelerate.html
- """
- )
- return (self.fn)(*args, **kwargs)
diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py
index 2ad97c4e04..f2c2e93cde 100644
--- a/frontend/catalyst/jax_extras/patches.py
+++ b/frontend/catalyst/jax_extras/patches.py
@@ -20,11 +20,14 @@
import jax
from jax._src.lax.lax import _nary_lower_hlo
from jax._src.lax.slicing import (
+ _argnum_weak_type,
+ _gather_dtype_rule,
_gather_shape_computation,
_is_sorted,
_no_duplicate_dims,
_rank,
_sorted_dims_in_range,
+ standard_primitive,
)
from jax._src.lib.mlir.dialects import hlo
from jax.core import AbstractValue, Tracer, concrete_aval
@@ -35,6 +38,7 @@
"_gather_shape_rule_dynamic",
"_sin_lowering2",
"_cos_lowering2",
+ "gather2_p",
)
@@ -186,6 +190,16 @@ def _gather_shape_rule_dynamic(
return _gather_shape_computation(indices, dimension_numbers, slice_sizes)
+# TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is
+# applied.
+gather2_p = standard_primitive(
+ _gather_shape_rule_dynamic,
+ _gather_dtype_rule,
+ "gather",
+ weak_type_rule=_argnum_weak_type(0),
+)
+
+
def _sin_lowering2(ctx, x):
"""Use hlo.sine lowering instead of the new sin lowering from jax 0.4.28"""
return _nary_lower_hlo(hlo.sine, ctx, x)
diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py
index 97db7386ab..01fbcefd04 100644
--- a/frontend/catalyst/jax_extras/tracing.py
+++ b/frontend/catalyst/jax_extras/tracing.py
@@ -45,12 +45,7 @@
)
from jax._src.lax.control_flow import _initial_style_jaxpr
from jax._src.lax.lax import _abstractify, cos_p, sin_p
-from jax._src.lax.slicing import (
- _argnum_weak_type,
- _gather_dtype_rule,
- _gather_lower,
- standard_primitive,
-)
+from jax._src.lax.slicing import _gather_lower
from jax._src.linear_util import annotate
from jax._src.pjit import _extract_implicit_args, _flat_axes_specs
from jax._src.source_info_util import current as jax_current
@@ -99,8 +94,8 @@
from catalyst.jax_extras.patches import (
_cos_lowering2,
- _gather_shape_rule_dynamic,
_sin_lowering2,
+ gather2_p,
get_aval2,
)
from catalyst.logging import debug_logger
@@ -514,14 +509,6 @@ def abstractify(args, kwargs):
in_type = infer_lambda_input_type(axes_specs, flat_args)
return in_type, in_tree
- # TODO: See the `_gather_shape_rule_dynamic` comment. Remove once the upstream change is
- # applied.
- gather2_p = standard_primitive(
- _gather_shape_rule_dynamic,
- _gather_dtype_rule,
- "gather",
- weak_type_rule=_argnum_weak_type(0),
- )
register_lowering(gather2_p, _gather_lower)
# TBD
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 649c4d4a75..ec38c0551e 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -227,8 +227,8 @@ class Folding(Enum):
"""
GLOBAL = "global"
- RANDOM = "random"
- ALL = "all"
+ RANDOM = "local-random"
+ ALL = "local-all"
##############
@@ -930,7 +930,7 @@ def _folding_attribute(ctx, folding):
ctx = ctx.module_context.context
return ir.OpaqueAttr.get(
"mitigation",
- ("folding " + Folding(folding).value).encode("utf-8"),
+ ("folding " + Folding(folding).name.lower()).encode("utf-8"),
ir.NoneType.get(ctx),
ctx,
)
@@ -950,13 +950,13 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
symbol_name = func_op.name.value
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
- scale_factors = args[-1]
+ num_folds = args[-1]
return ZneOp(
flat_output_types,
ir.FlatSymbolRefAttr.get(symbol_name),
mlir.flatten_lowering_ir_args(args[0:-1]),
_folding_attribute(ctx, folding),
- scale_factors,
+ num_folds,
).results
diff --git a/frontend/catalyst/jit.py b/frontend/catalyst/jit.py
index 88209b4d8c..7b2ae5d8a8 100644
--- a/frontend/catalyst/jit.py
+++ b/frontend/catalyst/jit.py
@@ -35,7 +35,6 @@
from catalyst.compiled_functions import CompilationCache, CompiledFunction
from catalyst.compiler import CompileOptions, Compiler
from catalyst.debug.instruments import instrument
-from catalyst.jax_extras.jax_scipy_linalg_warnings import JaxLinalgWarner
from catalyst.jax_tracer import lower_jaxpr_to_mlir, trace_to_jaxpr
from catalyst.logging import debug_logger, debug_logger_init
from catalyst.passes import _inject_transform_named_sequence
@@ -590,15 +589,6 @@ def closure(qnode, *args, **kwargs):
with Patcher(
(qml.QNode, "__call__", closure),
- # !!! TODO: fix jax.scipy numerical failures with properly fetched lapack calls
- # As of now, we raise a warning prompting the user to use a callback with catalyst.accelerate()
- # https://app.shortcut.com/xanaduai/story/70899/find-a-system-to-automatically-create-a-custom-call-library-from-the-one-in-jax
- # https://github.com/PennyLaneAI/catalyst/issues/753
- # https://github.com/PennyLaneAI/catalyst/issues/1071
- (jax.scipy.linalg, "expm", JaxLinalgWarner(jax.scipy.linalg.expm)),
- (jax.scipy.linalg, "lu", JaxLinalgWarner(jax.scipy.linalg.lu)),
- (jax.scipy.linalg, "lu_factor", JaxLinalgWarner(jax.scipy.linalg.lu_factor)),
- (jax.scipy.linalg, "lu_solve", JaxLinalgWarner(jax.scipy.linalg.lu_solve)),
):
# TODO: improve PyTree handling
diff --git a/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp
new file mode 100644
index 0000000000..b150eb34e2
--- /dev/null
+++ b/frontend/catalyst/utils/jax_cpu_lapack_kernels/lapack_kernels.cpp
@@ -0,0 +1,941 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// This file has been modified from its original form in the JAX project at
+// https://github.com/google/jax released under the Apache License, Version 2.0,
+// with the following copyright notice:
+
+// Copyright 2021 The JAX Authors.
+
+/*
+ * This file is a modified version of
+ *
+ * https://github.com/google/jax/blob/jaxlib-v0.4.28/jaxlib/cpu/lapack_kernels.cc
+ *
+ * from jaxlib-v0.4.28.
+ *
+ * See note in lapack_kernels.hpp for a high-level explanation of the
+ * modifications and the motivation for them. Specifically, the modifications
+ * made in this file include:
+ *
+ * 1. Used the C interfaces to the BLAS and LAPACK routines instead of the
+ * FORTRAN interfaces, and always use row-major matrix layout. This
+ * modification generally involves the following:
+ * - Adding the matrix layout parameter as the first argument to the BLAS/
+ * LAPACK call, either `CblasRowMajor` for BLAS or `LAPACK_ROW_MAJOR`
+ * for LAPACK.
+ * - Specifying the array leading dimensions (e.g. `lda`) such that they
+ * are dependent upon the matrix layout, rather than hard-coding them.
+ * Note that these should always evaluate to the value required for
+ * row-major matrix layout (typically the number of columns n of the
+ * matrix).
+ * - Removing parameters used by the FORTRAN interfaces but not by the C
+ * interfaces, e.g. the workspace array parameters `lwork`, `rwork`,
+ * `iwork`, etc.
+ * 2. Guarded the #include of the Abseil `dynamic_annotations.h` header and the
+ * uses of `ABSL_ANNOTATE_MEMORY_IS_INITIALIZED` with the `USE_ABSEIL_LIB`
+ * macro, since they are not needed for Catalyst.
+ * 3. Opportunistically improved const-correctness.
+ * 4. Applied Catalyst C++ code formatting.
+ */
+
+#include "lapack_kernels.hpp"
+
+#include
+#include
+#include
+#include
+
+#ifdef USE_ABSEIL_LIB
+#include "absl/base/dynamic_annotations.h"
+#endif
+
+namespace {
+
+inline int64_t catch_lapack_int_overflow(const std::string &source, int64_t value)
+{
+ if constexpr (sizeof(jax::lapack_int) == sizeof(int64_t)) {
+ return value;
+ }
+ else {
+ if (value > std::numeric_limits::max()) {
+ throw std::overflow_error(source + "(=" + std::to_string(value) +
+ ") exceeds maximum value of jax::lapack_int");
+ }
+ return value;
+ }
+}
+
+} // namespace
+
+namespace jax {
+
+static_assert(sizeof(lapack_int) == sizeof(int32_t), "Expected LAPACK integers to be 32-bit");
+
+// Trsm
+// ~~~~
+
+template typename RealTrsm::FnType *RealTrsm::fn = nullptr;
+
+template void RealTrsm::Kernel(void *out, void **data, XlaCustomCallStatus *)
+{
+ const int32_t left_side = *reinterpret_cast(data[0]);
+ const int32_t lower = *reinterpret_cast(data[1]);
+ const int32_t trans_a = *reinterpret_cast(data[2]);
+ const int32_t diag = *reinterpret_cast(data[3]);
+ const int m = *reinterpret_cast(data[4]);
+ const int n = *reinterpret_cast(data[5]);
+ const int batch = *reinterpret_cast(data[6]);
+ const T alpha = *reinterpret_cast(data[7]);
+ const T *a = reinterpret_cast(data[8]);
+ T *b = reinterpret_cast(data[9]);
+
+ T *x = reinterpret_cast(out);
+ if (x != b) {
+ std::memcpy(x, b,
+ static_cast(batch) * static_cast(m) *
+ static_cast(n) * sizeof(T));
+ }
+
+ constexpr CBLAS_ORDER corder = CblasRowMajor;
+ const CBLAS_SIDE cside = left_side ? CblasLeft : CblasRight;
+ const CBLAS_UPLO cuplo = lower ? CblasLower : CblasUpper;
+ const CBLAS_TRANSPOSE ctransa = (trans_a == 1) ? CblasTrans
+ : (trans_a == 2) ? CblasConjTrans
+ : CblasNoTrans;
+ const CBLAS_DIAG cdiag = diag ? CblasUnit : CblasNonUnit;
+ const int lda = left_side ? m : n;
+ const int ldb = (corder == CblasColMajor) ? m : n; // Note: m if col-major, n if row-major
+
+ const int64_t x_plus = static_cast(m) * static_cast(n);
+ const int64_t a_plus = static_cast(lda) * static_cast(lda);
+
+ for (int i = 0; i < batch; ++i) {
+ fn(CblasRowMajor, cside, cuplo, ctransa, cdiag, m, n, alpha, a, lda, x, ldb);
+ x += x_plus;
+ a += a_plus;
+ }
+}
+
+template typename ComplexTrsm::FnType *ComplexTrsm::fn = nullptr;
+
+template void ComplexTrsm::Kernel(void *out, void **data, XlaCustomCallStatus *)
+{
+ const int32_t left_side = *reinterpret_cast(data[0]);
+ const int32_t lower = *reinterpret_cast(data[1]);
+ const int32_t trans_a = *reinterpret_cast(data[2]);
+ const int32_t diag = *reinterpret_cast(data[3]);
+ const int m = *reinterpret_cast(data[4]);
+ const int n = *reinterpret_cast(data[5]);
+ const int batch = *reinterpret_cast(data[6]);
+ const T *alpha = reinterpret_cast(data[7]);
+ const T *a = reinterpret_cast(data[8]);
+ T *b = reinterpret_cast(data[9]);
+
+ T *x = reinterpret_cast(out);
+ if (x != b) {
+ std::memcpy(x, b,
+ static_cast(batch) * static_cast(m) *
+ static_cast(n) * sizeof(T));
+ }
+
+ constexpr CBLAS_ORDER corder = CblasRowMajor;
+ const CBLAS_SIDE cside = left_side ? CblasLeft : CblasRight;
+ const CBLAS_UPLO cuplo = lower ? CblasLower : CblasUpper;
+ const CBLAS_TRANSPOSE ctransa = (trans_a == 1) ? CblasTrans
+ : (trans_a == 2) ? CblasConjTrans
+ : CblasNoTrans;
+ const CBLAS_DIAG cdiag = diag ? CblasUnit : CblasNonUnit;
+ const int lda = left_side ? m : n;
+ const int ldb = (corder == CblasColMajor) ? m : n; // Note: m if col-major, n if row-major
+
+ const int64_t x_plus = static_cast(m) * static_cast(n);
+ const int64_t a_plus = static_cast(lda) * static_cast(lda);
+
+ for (int i = 0; i < batch; ++i) {
+ fn(CblasRowMajor, cside, cuplo, ctransa, cdiag, m, n, alpha, a, lda, x, ldb);
+ x += x_plus;
+ a += a_plus;
+ }
+}
+
+template struct RealTrsm;
+template struct RealTrsm;
+template struct ComplexTrsm>;
+template struct ComplexTrsm>;
+
+// Getrf
+// ~~~~~
+
+template typename Getrf::FnType *Getrf::fn = nullptr;
+
+template void Getrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int b = *(reinterpret_cast(data[0]));
+ const int m = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const T *a_in = reinterpret_cast(data[3]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ int *ipiv = reinterpret_cast(out[1]);
+ int *info = reinterpret_cast(out[2]);
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(m) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m;
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(corder, m, n, a_out, lda, ipiv);
+ a_out += static_cast(m) * static_cast(n);
+ ipiv += std::min(m, n);
+ ++info;
+ }
+}
+
+template struct Getrf;
+template struct Getrf;
+template struct Getrf>;
+template struct Getrf>;
+
+// Geqrf
+// ~~~~~
+
+template typename Geqrf::FnType *Geqrf::fn = nullptr;
+
+template void Geqrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int b = *(reinterpret_cast(data[0]));
+ const int m = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const int lwork = *(reinterpret_cast(data[3]));
+ const T *a_in = reinterpret_cast(data[4]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ T *tau = reinterpret_cast(out[1]);
+ int *info = reinterpret_cast(out[2]);
+ T *work = reinterpret_cast(out[3]);
+
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(m) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m;
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(LAPACK_ROW_MAJOR, m, n, a_out, lda, tau);
+ a_out += static_cast(m) * static_cast(n);
+ tau += std::min(m, n);
+ ++info;
+ }
+}
+
+template struct Geqrf;
+template struct Geqrf;
+template struct Geqrf>;
+template struct Geqrf>;
+
+// Orgqr
+// ~~~~~
+
+template typename Orgqr::FnType *Orgqr::fn = nullptr;
+
+template void Orgqr::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int b = *(reinterpret_cast(data[0]));
+ const int m = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const int k = *(reinterpret_cast(data[3]));
+ const int lwork = *(reinterpret_cast(data[4]));
+ const T *a_in = reinterpret_cast(data[5]);
+ T *tau = reinterpret_cast(data[6]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ int *info = reinterpret_cast(out[1]);
+ T *work = reinterpret_cast(out[2]);
+
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(m) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m;
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(LAPACK_ROW_MAJOR, m, n, k, a_out, lda, tau);
+ a_out += static_cast(m) * static_cast(n);
+ tau += k;
+ ++info;
+ }
+}
+
+template struct Orgqr;
+template struct Orgqr;
+template struct Orgqr>;
+template struct Orgqr>;
+
+// Potrf
+// ~~~~~
+
+template typename Potrf::FnType *Potrf::fn = nullptr;
+
+template void Potrf::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int32_t lower = *(reinterpret_cast(data[0]));
+ const int b = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const T *a_in = reinterpret_cast(data[3]);
+ const char uplo = lower ? 'L' : 'U';
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ int *info = reinterpret_cast(out[1]);
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(n) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(corder, uplo, n, a_out, n);
+ a_out += static_cast(n) * static_cast(n);
+ ++info;
+ }
+}
+
+template struct Potrf;
+template struct Potrf;
+template struct Potrf>;
+template struct Potrf>;
+
+// Gesdd
+
+static char GesddJobz(bool job_opt_compute_uv, bool job_opt_full_matrices)
+{
+ if (!job_opt_compute_uv) {
+ return 'N';
+ }
+ else if (!job_opt_full_matrices) {
+ return 'S';
+ }
+ return 'A';
+}
+
+static int Gesdd_ldu(const int order, const char jobz, const int m, const int n)
+{
+ int ldu = 0;
+ if (jobz == 'N') {
+ ldu = 1;
+ }
+ else if (jobz == 'A') {
+ ldu = m;
+ }
+ else if (jobz == 'S') {
+ if (m >= n) {
+ ldu = (order == LAPACK_ROW_MAJOR) ? n : m;
+ }
+ else {
+ ldu = m;
+ }
+ }
+ return ldu;
+}
+
+static int Gesdd_ldvt(const int order, const char jobz, const int m, const int n)
+{
+ int ldu = 0;
+ if (jobz == 'N') {
+ ldu = 1;
+ }
+ else if (jobz == 'A') {
+ ldu = n;
+ }
+ else if (jobz == 'S') {
+ if (m >= n) {
+ ldu = n;
+ }
+ else {
+ ldu = (order == LAPACK_ROW_MAJOR) ? n : m;
+ }
+ }
+ return ldu;
+}
+
+template typename RealGesdd::FnType *RealGesdd::fn = nullptr;
+
+template void RealGesdd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int32_t job_opt_full_matrices = *(reinterpret_cast(data[0]));
+ const int32_t job_opt_compute_uv = *(reinterpret_cast(data[1]));
+ const int b = *(reinterpret_cast(data[2]));
+ const int m = *(reinterpret_cast(data[3]));
+ const int n = *(reinterpret_cast(data[4]));
+ const int lwork = *(reinterpret_cast(data[5]));
+ T *a_in = reinterpret_cast(data[6]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ T *s = reinterpret_cast(out[1]);
+ T *u = reinterpret_cast(out[2]);
+ T *vt = reinterpret_cast(out[3]);
+ int *info = reinterpret_cast(out[4]);
+ int *iwork = reinterpret_cast(out[5]);
+ T *work = reinterpret_cast(out[6]);
+
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(m) * static_cast(n) *
+ sizeof(T));
+ }
+
+ const char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m;
+ const int ldu = Gesdd_ldu(corder, jobz, m, n);
+ const int tdu = ldu;
+ const int ldvt = Gesdd_ldvt(corder, jobz, m, n);
+
+ // In row-major layout `ldvt` is n, but with jobz == 'S' only min(m, n) rows of `vt` are
+ // produced, so the per-batch stride must use that row count (assumes batches are packed
+ // contiguously, as for `a_out` and `u`). Other jobz values keep the previous stride.
+ const int64_t vt_rows = (jobz == 'S') ? std::min(m, n) : ldvt;
+ for (int i = 0; i < b; ++i) {
+     *info = fn(corder, jobz, m, n, a_out, lda, s, u, ldu, vt, ldvt);
+     a_out += static_cast<int64_t>(m) * n;
+     s += std::min(m, n);
+     u += static_cast<int64_t>(m) * tdu;
+     vt += vt_rows * n;
+     ++info;
+ }
+}
+
+template typename ComplexGesdd::FnType *ComplexGesdd::fn = nullptr;
+
+template
+void ComplexGesdd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int32_t job_opt_full_matrices = *(reinterpret_cast(data[0]));
+ const int32_t job_opt_compute_uv = *(reinterpret_cast(data[1]));
+ const int b = *(reinterpret_cast(data[2]));
+ const int m = *(reinterpret_cast(data[3]));
+ const int n = *(reinterpret_cast(data[4]));
+ const int lwork = *(reinterpret_cast(data[5]));
+ T *a_in = reinterpret_cast(data[6]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ typename T::value_type *s = reinterpret_cast(out[1]);
+ T *u = reinterpret_cast(out[2]);
+ T *vt = reinterpret_cast(out[3]);
+ int *info = reinterpret_cast(out[4]);
+ int *iwork = reinterpret_cast(out[5]);
+ typename T::value_type *rwork = reinterpret_cast(out[6]);
+ T *work = reinterpret_cast(out[7]);
+
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(m) * static_cast(n) *
+ sizeof(T));
+ }
+
+ const char jobz = GesddJobz(job_opt_compute_uv, job_opt_full_matrices);
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const int lda = (corder == LAPACK_ROW_MAJOR) ? n : m;
+ const int ldu = Gesdd_ldu(corder, jobz, m, n);
+ const int tdu = ldu;
+ const int ldvt = Gesdd_ldvt(corder, jobz, m, n);
+
+ // In row-major layout `ldvt` is n, but with jobz == 'S' only min(m, n) rows of `vt` are
+ // produced, so the per-batch stride must use that row count (assumes batches are packed
+ // contiguously, as for `a_out` and `u`). Other jobz values keep the previous stride.
+ const int64_t vt_rows = (jobz == 'S') ? std::min(m, n) : ldvt;
+ for (int i = 0; i < b; ++i) {
+     *info = fn(LAPACK_ROW_MAJOR, jobz, m, n, a_out, lda, s, u, ldu, vt, ldvt);
+     a_out += static_cast<int64_t>(m) * n;
+     s += std::min(m, n);
+     u += static_cast<int64_t>(m) * tdu;
+     vt += vt_rows * n;
+     ++info;
+ }
+}
+
+template struct RealGesdd;
+template struct RealGesdd;
+template struct ComplexGesdd>;
+template struct ComplexGesdd>;
+
+// Syevd/Heevd
+// ~~~~~~~~~~~
+
+template typename RealSyevd::FnType *RealSyevd::fn = nullptr;
+
+template void RealSyevd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int32_t lower = *(reinterpret_cast(data[0]));
+ const int b = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const T *a_in = reinterpret_cast(data[3]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ T *w_out = reinterpret_cast(out[1]);
+ int *info = reinterpret_cast(out[2]);
+ T *work = reinterpret_cast(out[3]);
+ int *iwork = reinterpret_cast(out[4]);
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(n) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const char jobz = 'V';
+ const char uplo = lower ? 'L' : 'U';
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(corder, jobz, uplo, n, a_out, n, w_out);
+ a_out += static_cast(n) * n;
+ w_out += n;
+ ++info;
+ }
+}
+
+template typename ComplexHeevd::FnType *ComplexHeevd::fn = nullptr;
+
+template
+void ComplexHeevd::Kernel(void *out_tuple, void **data, XlaCustomCallStatus *)
+{
+ const int32_t lower = *(reinterpret_cast(data[0]));
+ const int b = *(reinterpret_cast(data[1]));
+ const int n = *(reinterpret_cast(data[2]));
+ const T *a_in = reinterpret_cast(data[3]);
+
+ void **out = reinterpret_cast(out_tuple);
+ T *a_out = reinterpret_cast(out[0]);
+ typename T::value_type *w_out = reinterpret_cast(out[1]);
+ int *info = reinterpret_cast(out[2]);
+ T *work = reinterpret_cast(out[3]);
+ typename T::value_type *rwork = reinterpret_cast(out[4]);
+ int *iwork = reinterpret_cast(out[5]);
+ if (a_out != a_in) {
+ std::memcpy(a_out, a_in,
+ static_cast(b) * static_cast(n) * static_cast(n) *
+ sizeof(T));
+ }
+
+ constexpr int corder = LAPACK_ROW_MAJOR;
+ const char jobz = 'V';
+ const char uplo = lower ? 'L' : 'U';
+
+ for (int i = 0; i < b; ++i) {
+ *info = fn(corder, jobz, uplo, n, a_out, n, w_out);
+ a_out += static_cast(n) * n;
+ w_out += n;
+ ++info;
+ }
+}
+
+template struct RealSyevd;
+template struct RealSyevd;
+template struct ComplexHeevd>;
+template struct ComplexHeevd>;
+
+// LAPACK uses a packed representation to represent a mixture of real
+// eigenvectors and complex conjugate pairs. This helper unpacks the
+// representation into regular complex matrices.
+template
+static void UnpackEigenvectors(int n, const T *im_eigenvalues, const T *packed,
+ std::complex *unpacked)
+{
+ T re, im;
+ int j;
+ j = 0;
+ while (j < n) {
+ if (im_eigenvalues[j] == 0. || std::isnan(im_eigenvalues[j])) {
+ for (int k = 0; k < n; ++k) {
+ unpacked[j * n + k] = {packed[j * n + k], 0.};
+ }
+ ++j;
+ }
+ else {
+ for (int k = 0; k < n; ++k) {
+ re = packed[j * n + k];
+ im = packed[(j + 1) * n + k];
+ unpacked[j * n + k] = {re, im};
+ unpacked[(j + 1) * n + k] = {re, -im};
+ }
+ j += 2;
+ }
+ }
+}
+
+// Geev
+// ~~~~
+
+template typename RealGeev::FnType *RealGeev::fn = nullptr;
+
+template void RealGeev