From c1391c0b35397bb03a2aed927a6d097b4a7260ba Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 24 Sep 2024 09:46:44 +0100 Subject: [PATCH] Expose the `linalg` namespace and include in status page (#581) * Add matmul, matrix_transpose, tensordot, vecdot to linalg namespace * Move outer to linalg namespace * Remove flip from list of unimplemented functions since it was added in #114 * Remove unstack from list of unimplemented functions since it was added in #575 * Add link to cumulative_sum PR * Add linalg table to status page --- api_status.md | 36 +++++++++++++++++++-- cubed/__init__.py | 3 +- cubed/array_api/__init__.py | 4 +-- cubed/array_api/linalg.py | 14 +++++++- cubed/array_api/linear_algebra_functions.py | 4 --- cubed/tests/test_array_api.py | 2 +- docs/array-api.md | 1 - 7 files changed, 50 insertions(+), 14 deletions(-) diff --git a/api_status.md b/api_status.md index bc6188393..de6551bac 100644 --- a/api_status.md +++ b/api_status.md @@ -1,6 +1,6 @@ ## Array API Coverage Implementation Status -Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [linear algebra extensions](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) and [Fourier transform functions¶](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported. +Cubed supports version [2022.12](https://data-apis.org/array-api/2022.12/index.html) of the Python array API standard, with a few exceptions noted below. The [Fourier transform functions](https://data-apis.org/array-api/2022.12/extensions/fourier_transform_functions.html) are *not* supported. Support for version [2023.12](https://data-apis.org/array-api/2023.12/index.html) is tracked in Cubed issue [#438](https://github.com/cubed-dev/cubed/issues/438). @@ -67,7 +67,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `squeeze` | :white_check_mark: | | | | | `stack` | :white_check_mark: | | | | | `tile` | :x: | 2023.12 | | -| | `unstack` | :x: | 2023.12 | | +| | `unstack` | :white_check_mark: | 2023.12 | | | Searching Functions | `argmax` | :white_check_mark: | | | | | `argmin` | :white_check_mark: | | | | | `nonzero` | :x: | | Shape is data dependent | @@ -79,7 +79,7 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `unique_values` | :x: | | Shape is data dependent | | Sorting Functions | `argsort` | :x: | | Not in Dask | | | `sort` | :x: | | Not in Dask | -| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | | +| Statistical Functions | `cumulative_sum` | :x: | 2023.12 | WIP [#531](https://github.com/cubed-dev/cubed/pull/531) | | | `max` | :white_check_mark: | | | | | `mean` | :white_check_mark: | | | | | `min` | :white_check_mark: | | | @@ -89,3 +89,33 @@ This table shows which parts of the the [Array API](https://data-apis.org/array- | | `var` | :x: | | Like `mean`, [#29](https://github.com/cubed-dev/cubed/issues/29) | | Utility Functions | `all` | :white_check_mark: | | | | | `any` | :white_check_mark: | | | + +### Linear Algebra Extension + +A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/extensions/linear_algebra_functions.html) functions are supported, as indicated in this table. + +| Category | Object/Function | Implemented | Version | Notes | +| ------------------------ | ------------------- | ------------------ | ---------- | ---------------------------- | +| Linear Algebra Functions | `cholesky` | :x: | | | +| | `cross` | :x: | | | +| | `det` | :x: | | | +| | `diagonal` | :x: | | | +| | `eigh` | :x: | | | +| | `eigvalsh` | :x: | | | +| | `inv` | :x: | | | +| | `matmul` | :white_check_mark: | | | +| | `matrix_norm` | :x: | | | +| | `matrix_power` | :x: | | | +| | `matrix_rank` | :x: | | | +| | `matrix_transpose` | :white_check_mark: | | | +| | `outer` | :white_check_mark: | | | +| | `pinv` | :x: | | | +| | `qr` | :white_check_mark: | | | +| | `slogdet` | :x: | | | +| | `solve` | :x: | | | +| | `svd` | :x: | | | +| | `svdvals` | :x: | | | +| | `tensordot` | :white_check_mark: | | | +| | `trace` | :x: | | | +| | `vecdot` | :white_check_mark: | | | +| | `vectornorm` | :x: | | | diff --git a/cubed/__init__.py b/cubed/__init__.py index ed529ab73..790d54d48 100644 --- a/cubed/__init__.py +++ b/cubed/__init__.py @@ -267,12 +267,11 @@ from .array_api.linear_algebra_functions import ( matmul, matrix_transpose, - outer, tensordot, vecdot, ) -__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] +__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"] from .array_api.manipulation_functions import ( broadcast_arrays, diff --git a/cubed/array_api/__init__.py b/cubed/array_api/__init__.py index e7e6ffbfe..993c9bf3d 100644 --- a/cubed/array_api/__init__.py +++ b/cubed/array_api/__init__.py @@ -212,9 +212,9 @@ __all__ += ["take"] -from .linear_algebra_functions import matmul, matrix_transpose, outer, tensordot, vecdot +from .linear_algebra_functions import matmul, matrix_transpose, tensordot, vecdot -__all__ += ["matmul", "matrix_transpose", "outer", "tensordot", "vecdot"] +__all__ += ["matmul", "matrix_transpose", "tensordot", "vecdot"] from .manipulation_functions import ( broadcast_arrays, diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py index 03fb072af..91b253944 100644 --- a/cubed/array_api/linalg.py +++ b/cubed/array_api/linalg.py @@ -1,11 +1,23 @@ from typing import NamedTuple from cubed.array_api.array_object import Array + +# These functions are in both the main and linalg namespaces +from cubed.array_api.linear_algebra_functions import ( # noqa: F401 + matmul, + matrix_transpose, + tensordot, + vecdot, +) from cubed.backend_array_api import namespace as nxp -from cubed.core.ops import general_blockwise, map_direct, merge_chunks +from cubed.core.ops import blockwise, general_blockwise, map_direct, merge_chunks from cubed.utils import array_memory, get_item +def outer(x1, x2, /): + return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) + + class QRResult(NamedTuple): Q: Array R: Array diff --git a/cubed/array_api/linear_algebra_functions.py b/cubed/array_api/linear_algebra_functions.py index 8e5829f4f..538fb2a02 100644 --- a/cubed/array_api/linear_algebra_functions.py +++ b/cubed/array_api/linear_algebra_functions.py @@ -95,10 +95,6 @@ def matrix_transpose(x, /): return permute_dims(x, axes) -def outer(x1, x2, /): - return blockwise(nxp.linalg.outer, "ij", x1, "i", x2, "j", dtype=x1.dtype) - - def tensordot(x1, x2, /, *, axes=2, use_new_impl=True, split_every=None): from cubed.array_api.statistical_functions import sum diff --git a/cubed/tests/test_array_api.py b/cubed/tests/test_array_api.py index b34b8bbba..21c10d4ab 100644 --- a/cubed/tests/test_array_api.py +++ b/cubed/tests/test_array_api.py @@ -417,7 +417,7 @@ def test_matmul_modal(modal_executor): def test_outer(spec, executor): a = xp.asarray([0, 1, 2], chunks=2, spec=spec) b = xp.asarray([10, 50, 100], chunks=2, spec=spec) - c = xp.outer(a, b) + c = xp.linalg.outer(a, b) assert_array_equal(c.compute(executor=executor), np.outer([0, 1, 2], [10, 50, 100])) diff --git a/docs/array-api.md b/docs/array-api.md index d6550ee32..9b0d5d40f 100644 --- a/docs/array-api.md +++ b/docs/array-api.md @@ -15,7 +15,6 @@ The following parts of the standard are not implemented: | Array object | In-place Ops | | Creation Functions | `from_dlpack` | | Indexing | Boolean array | -| Manipulation Functions | `flip` | | Searching Functions | `nonzero` | | Set Functions | `unique_all` | | | `unique_counts` |