diff --git a/api_status.md b/api_status.md
index a71b54c7..ff1f2c59 100644
--- a/api_status.md
+++ b/api_status.md
@@ -113,8 +113,8 @@ A few of the [linear algebra extension](https://data-apis.org/array-api/2022.12/
 | | `qr` | :white_check_mark: | | |
 | | `slogdet` | :x: | | |
 | | `solve` | :x: | | |
-| | `svd` | :x: | | |
-| | `svdvals` | :x: | | |
+| | `svd` | :white_check_mark: | | |
+| | `svdvals` | :white_check_mark: | | |
 | | `tensordot` | :white_check_mark: | | |
 | | `trace` | :x: | | |
 | | `vecdot` | :white_check_mark: | | |
diff --git a/cubed/array_api/linalg.py b/cubed/array_api/linalg.py
index 58b6f65b..8a907a23 100644
--- a/cubed/array_api/linalg.py
+++ b/cubed/array_api/linalg.py
@@ -1,10 +1,10 @@
 from typing import NamedTuple

 from cubed.array_api.array_object import Array
-
-# These functions are in both the main and linalg namespaces
 from cubed.array_api.data_type_functions import result_type
 from cubed.array_api.dtypes import _floating_dtypes
+
+# These functions are in both the main and linalg namespaces
 from cubed.array_api.linear_algebra_functions import (  # noqa: F401
     matmul,
     matrix_transpose,
@@ -12,7 +12,7 @@
     vecdot,
 )
 from cubed.backend_array_api import namespace as nxp
-from cubed.core.ops import blockwise, general_blockwise, merge_chunks
+from cubed.core.ops import blockwise, general_blockwise, merge_chunks, squeeze
 from cubed.utils import array_memory, get_item
@@ -27,6 +27,12 @@ class QRResult(NamedTuple):
     R: Array


+class SVDResult(NamedTuple):
+    U: Array
+    S: Array
+    Vh: Array
+
+
 def qr(x, /, *, mode="reduced") -> QRResult:
     if x.ndim != 2:
         raise ValueError("qr requires x to have 2 dimensions.")
@@ -43,10 +49,11 @@ def qr(x, /, *, mode="reduced") -> QRResult:
             "Consider rechunking so there is only a single column chunk."
         )

-    return tsqr(x)
+    Q, R, _, _, _ = tsqr(x)
+    return QRResult(Q, R)


-def tsqr(x) -> QRResult:
+def tsqr(x, compute_svd=False, finalize_svd=True):
     """Direct Tall-and-Skinny QR algorithm

     From:
@@ -57,18 +64,22 @@
     https://arxiv.org/abs/1301.1071
     """

-    # follows Algorithm 2 from Benson et al
+    # follows Algorithm 2 from Benson et al, modified for SVD
     Q1, R1 = _qr_first_step(x)

     if _r1_is_too_big(R1):
         R1 = _rechunk_r1(R1)
-        Q2, R2 = tsqr(R1)
+        Q2, R2, U, S, Vh = tsqr(R1, compute_svd=compute_svd, finalize_svd=False)
     else:
-        Q2, R2 = _qr_second_step(R1)
+        Q2, R2, U, S, Vh = _qr_second_step(R1, compute_svd=compute_svd)

     Q, R = _qr_third_step(Q1, Q2), R2

-    return QRResult(Q, R)
+    if compute_svd and finalize_svd:
+        U = Q @ U  # fourth step (SVD only)
+        S = squeeze(S, axis=1)  # remove extra dim
+
+    return Q, R, U, S, Vh


 def _qr_first_step(A):
@@ -108,7 +119,7 @@
     return merge_chunks(R1, chunks=chunks)


-def _qr_second_step(R1):
+def _qr_second_step(R1, compute_svd=False):
     R1_single = _merge_into_single_chunk(R1)

     Q2_shape = R1.shape
@@ -117,17 +128,38 @@
     n = R1.shape[1]
     R2_shape = (n, n)
     R2_chunks = R2_shape  # single chunk
-    # qr implementation creates internal array buffers
-    extra_projected_mem = R1_single.chunkmem * 4
-    Q2, R2 = map_blocks_multiple_outputs(
-        nxp.linalg.qr,
-        R1_single,
-        shapes=[Q2_shape, R2_shape],
-        dtypes=[R1.dtype, R1.dtype],
-        chunkss=[Q2_chunks, R2_chunks],
-        extra_projected_mem=extra_projected_mem,
-    )
-    return QRResult(Q2, R2)
+
+    if not compute_svd:
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2 = map_blocks_multiple_outputs(
+            nxp.linalg.qr,
+            R1_single,
+            shapes=[Q2_shape, R2_shape],
+            dtypes=[R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, None, None, None
+    else:
+        U_shape = (n, n)
+        U_chunks = U_shape
+        S_shape = (n, 1)  # extra dim since multiple outputs must have same numblocks
+        S_chunks = S_shape
+        Vh_shape = (n, n)
+        Vh_chunks = Vh_shape
+
+        # qr implementation creates internal array buffers
+        extra_projected_mem = R1_single.chunkmem * 4
+        Q2, R2, U, S, Vh = map_blocks_multiple_outputs(
+            _qr2,
+            R1_single,
+            shapes=[Q2_shape, R2_shape, U_shape, S_shape, Vh_shape],
+            dtypes=[R1.dtype, R1.dtype, R1.dtype, R1.dtype, R1.dtype],
+            chunkss=[Q2_chunks, R2_chunks, U_chunks, S_chunks, Vh_chunks],
+            extra_projected_mem=extra_projected_mem,
+        )
+        return Q2, R2, U, S, Vh


 def _merge_into_single_chunk(x, split_every=4):
@@ -138,6 +170,13 @@
     return x


+def _qr2(a):
+    Q, R = nxp.linalg.qr(a)
+    U, S, Vh = nxp.linalg.svd(R)
+    S = S[:, nxp.newaxis]  # add extra dim
+    return Q, R, U, S, Vh
+
+
 def _qr_third_step(Q1, Q2):
     m, n = Q1.chunksize
     k, _ = Q1.numblocks
@@ -174,6 +213,30 @@
     return q1 @ q2


+def svd(x, /, *, full_matrices=True) -> SVDResult:
+    if full_matrices:
+        raise ValueError("Cubed arrays only support using full_matrices=False")
+
+    nb = x.numblocks
+    # TODO: optimize case nb[0] == nb[1] == 1
+    if nb[0] > nb[1]:
+        _, _, U, S, Vh = tsqr(x, compute_svd=True)
+        truncate = x.shape[0] < x.shape[1]
+    else:
+        _, _, Vht, S, Ut = tsqr(x.T, compute_svd=True)
+        U, S, Vh = Ut.T, S, Vht.T
+        truncate = x.shape[0] > x.shape[1]
+    if truncate:  # from dask
+        k = min(x.shape)
+        U, Vh = U[:, :k], Vh[:k, :]
+    return SVDResult(U, S, Vh)
+
+
+def svdvals(x, /):
+    _, S, _ = svd(x, full_matrices=False)
+    return S
+
+
 def map_blocks_multiple_outputs(
     func,
     *args,
diff --git a/cubed/tests/test_linalg.py b/cubed/tests/test_linalg.py
index 0ef72636..585a72d4 100644
--- a/cubed/tests/test_linalg.py
+++ b/cubed/tests/test_linalg.py
@@ -57,3 +57,47 @@ def test_qr_chunking():
         match=r"qr only supports tall-and-skinny \(single column chunk\) arrays.",
     ):
         xp.linalg.qr(A)
+
+
+def test_svd():
+    A = np.reshape(np.arange(32, dtype=np.float64), (16, 2))
+
+    U, S, Vh = xp.linalg.svd(xp.asarray(A, chunks=(4, 2)), full_matrices=False)
+    U, S, Vh = cubed.compute(U, S, Vh)
+
+    assert_allclose(U * S @ Vh, A, atol=1e-08)
+    assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08)  # U must be orthonormal
+    assert_allclose(Vh @ Vh.T, np.eye(2, 2), atol=1e-08)  # Vh must be orthonormal
+
+
+def test_svd_recursion():
+    A = np.reshape(np.arange(128, dtype=np.float64), (64, 2))
+
+    # find a memory setting where recursion happens
+    found = False
+    for factor in range(4, 16):
+        spec = cubed.Spec(allowed_mem=128 * factor, reserved_mem=0)
+
+        try:
+            U, S, Vh = xp.linalg.svd(
+                xp.asarray(A, chunks=(8, 2), spec=spec), full_matrices=False
+            )
+
+            found = True
+            plan_unopt = arrays_to_plan(U, S, Vh)._finalize()
+            assert plan_unopt.num_primitive_ops() > 4  # more than without recursion
+
+            U, S, Vh = cubed.compute(U, S, Vh)
+
+            assert_allclose(U * S @ Vh, A, atol=1e-08)
+            assert_allclose(U.T @ U, np.eye(2, 2), atol=1e-08)  # U must be orthonormal
+            assert_allclose(
+                Vh @ Vh.T, np.eye(2, 2), atol=1e-08
+            )  # Vh must be orthonormal
+
+            break
+
+        except ValueError:
+            pass  # not enough memory
+
+    assert found
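
Note (not part of the patch): the `compute_svd=True` path follows Algorithm 2 of
Benson et al. with one extra step. After the tall-and-skinny QR x = Q R, an SVD of
the small n x n factor R gives x = (Q @ U_r) S Vh, so only the "fourth step"
U = Q @ U in `tsqr` touches a tall array. A minimal single-chunk NumPy sketch of
the same factorization (names here are illustrative, not patch code):

    import numpy as np

    x = np.arange(32, dtype=np.float64).reshape(16, 2)  # tall and skinny

    Q, R = np.linalg.qr(x)         # steps 1-3, done blockwise in the patch
    U_r, S, Vh = np.linalg.svd(R)  # SVD of the small n x n factor (as in _qr2)
    U = Q @ U_r                    # fourth step: lift U back to x's row space

    # same checks as test_svd: reconstruction and orthonormal columns
    np.testing.assert_allclose((U * S) @ Vh, x, atol=1e-8)
    np.testing.assert_allclose(U.T @ U, np.eye(2), atol=1e-8)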
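The transposed branch of `svd` (more column chunks than row chunks) relies on the
identity: if x.T = U' S Vh', then x = Vh'.T diag(S) U'.T, so the factors of the
transpose are swapped and transposed; `truncate` (taken from dask) then trims the
surplus singular vectors. Again a NumPy-only sketch under the same caveat:

    import numpy as np

    x = np.arange(32, dtype=np.float64).reshape(2, 16)  # short and wide

    Ut_, S, Vht_ = np.linalg.svd(x.T, full_matrices=False)
    U, Vh = Vht_.T, Ut_.T  # swap and transpose the factors

    np.testing.assert_allclose((U * S) @ Vh, x, atol=1e-8)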