Skip to content

qttools.kernels.linalg.svd#

[docs] module qttools.kernels.linalg.svd

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

import numba as nb
import numpy as np

from qttools import NDArray, xp
from qttools.utils.gpu_utils import get_any_location, get_array_module_name


@nb.njit(parallel=True, cache=True, no_rewrites=True)
def _svd_numba(
    A: NDArray,
    full_matrices: bool = True,
) -> tuple[NDArray, NDArray, NDArray]:
    """Batched singular value decomposition, parallelized with numba.

    Every matrix ``A[i]`` of the stack is decomposed independently with
    ``np.linalg.svd`` inside a ``prange`` loop over the leading (batch)
    axis. The input must be exactly three-dimensional.

    Parameters
    ----------
    A : NDArray
        Stack of matrices of shape ``(batch, m, n)``.
    full_matrices : bool, optional
        If True, ``u`` and ``vh`` have shapes ``(batch, m, m)`` and
        ``(batch, n, n)``; otherwise ``(batch, m, k)`` and
        ``(batch, k, n)`` with ``k = min(m, n)`` (see numpy.linalg.svd).

    Returns
    -------
    NDArray
        The left singular vectors.
    NDArray
        The singular values, shape ``(batch, k)``. The array is
        allocated with the dtype of ``A``.
    NDArray
        The right singular vectors.

    """

    num_matrices = A.shape[0]
    rows = A.shape[-2]
    cols = A.shape[-1]
    rank = min(rows, cols)

    # Only the trailing dimension of u and the leading matrix dimension
    # of vh depend on whether the full or the reduced SVD is requested.
    u_cols = rows if full_matrices else rank
    vh_rows = cols if full_matrices else rank

    u = np.empty((num_matrices, rows, u_cols), dtype=A.dtype)
    s = np.empty((num_matrices, rank), dtype=A.dtype)
    vh = np.empty((num_matrices, vh_rows, cols), dtype=A.dtype)

    # Each iteration writes to its own slice of the outputs, so the
    # batch entries can be processed in parallel.
    for idx in nb.prange(num_matrices):
        left, sigma, right = np.linalg.svd(A[idx], full_matrices=full_matrices)
        u[idx] = left
        s[idx] = sigma
        vh[idx] = right

    return u, s, vh


def svd(
    A: NDArray,
    full_matrices: bool = True,
    compute_module: str = "numpy",
    output_module: str | None = None,
    use_pinned_memory: bool = True,
) -> tuple[NDArray, NDArray, NDArray]:
    """Computes the singular value decomposition on a given location.

    ``A`` may be a single matrix or a stack of matrices; the
    decomposition acts on the last two axes and all leading (batch)
    axes are preserved in the outputs.

    The kwargs compute_uv and hermitian are not supported.
    They are implicitly set to True and False, respectively.

    Parameters
    ----------
    A : NDArray
        The matrix or stack of matrices.
    full_matrices : bool, optional
        Whether to compute the full matrices u and vh (see numpy.linalg.svd).
    compute_module : str, optional
        The location where to compute the singular value decomposition.
        Can be either "numpy" or "cupy".
    output_module : str, optional
        The location where to store the singular value decomposition.
        Can be either "numpy" or "cupy". If None, the output location
        is the same as the input location.
    use_pinned_memory : bool, optional
        Whether to use pinned memory if cupy is used.
        Default is `True`.

    Returns
    -------
    NDArray
        The left singular vectors.
    NDArray
        The singular values.
    NDArray
        The right singular vectors.

    Raises
    ------
    ValueError
        If any of the input, compute or output locations is "cupy"
        while ``xp`` is numpy, or if ``compute_module`` is neither
        "numpy" nor "cupy".

    """
    input_module = get_array_module_name(A)

    if output_module is None:
        output_module = input_module

    if xp.__name__ == "numpy" and "cupy" in (
        compute_module,
        output_module,
        input_module,
    ):
        raise ValueError("Cannot do gpu computation with numpy as xp.")

    # memcopy to correct location
    A = get_any_location(A, compute_module, use_pinned_memory=use_pinned_memory)

    if compute_module == "cupy":
        u, s, vh = xp.linalg.svd(A, full_matrices=full_matrices)
    elif compute_module == "numpy":
        # The numba kernel only accepts 3-D input: flatten all leading
        # batch axes into one and restore them afterwards.
        batch_shape = A.shape[:-2]
        A = A.reshape((-1, *A.shape[-2:]))

        u, s, vh = _svd_numba(A, full_matrices)

        # The trailing dimensions already reflect `full_matrices`, so
        # they are taken from the kernel outputs directly.
        u = u.reshape((*batch_shape, *u.shape[-2:]))
        s = s.reshape((*batch_shape, s.shape[-1]))
        vh = vh.reshape((*batch_shape, *vh.shape[-2:]))
    else:
        # Without this branch u, s and vh would be unbound below and the
        # caller would get an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported compute_module {compute_module!r}; "
            "expected 'numpy' or 'cupy'."
        )

    return (
        get_any_location(u, output_module, use_pinned_memory=use_pinned_memory),
        get_any_location(s, output_module, use_pinned_memory=use_pinned_memory),
        get_any_location(vh, output_module, use_pinned_memory=use_pinned_memory),
    )