Skip to content

qttools.kernels.linalg.qr

[docs] module qttools.kernels.linalg.qr

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import numba as nb
import numpy as np

from qttools import NDArray, xp
from qttools.profiling import Profiler
from qttools.utils.gpu_utils import get_any_location, get_array_module_name

# Module-level profiler; its `profile` decorator wraps the functions below.
profiler = Profiler()


@profiler.profile(level="debug")
@nb.njit(parallel=True, cache=True, no_rewrites=True)
def _qr_numba(
    A: NDArray,
) -> tuple[NDArray, NDArray]:
    """Batched reduced QR decomposition, parallelized over the batch with numba.

    Parameters
    ----------
    A : NDArray
        Stack of matrices. The first axis indexes the batch, and each
        ``A[b]`` is the 2D matrix passed to ``np.linalg.qr``.

    Returns
    -------
    NDArray
        Stack of Q factors with shape ``(batch, m, min(m, n))``.
    NDArray
        Stack of upper-triangular R factors with shape
        ``(batch, min(m, n), n)``.

    """
    num_matrices = A.shape[0]
    rows = A.shape[-2]
    cols = A.shape[-1]
    rank = min(rows, cols)

    # Preallocate the outputs so every parallel iteration writes only its
    # own slot of the batch.
    q_out = np.empty((num_matrices, rows, rank), dtype=A.dtype)
    r_out = np.empty((num_matrices, rank, cols), dtype=A.dtype)

    for b in nb.prange(num_matrices):
        # Only the reduced mode is available under numba.
        q_b, r_b = np.linalg.qr(A[b])
        q_out[b, :, :] = q_b
        r_out[b, :, :] = r_b

    return q_out, r_out


@profiler.profile(level="api")
def qr(
    A: NDArray,
    compute_module: str = "numpy",
    output_module: str | None = None,
    use_pinned_memory: bool = True,
) -> tuple[NDArray, NDArray]:
    """Computes the reduced QR decomposition of a batch of matrices.

    If compute_module is "numpy", the computation is done with numpy and
    parallelized with numba. Only mode 'reduced' is supported due to numba
    limitations.

    Parameters
    ----------
    A : NDArray
        The matrices. The last two axes hold the ``(m, n)`` matrices; any
        leading axes are treated as batch dimensions.
    compute_module : str, optional
        The location where to compute the QR decomposition.
        Can be either "numpy" or "cupy". Default is "numpy".
    output_module : str, optional
        The location where to store the QR decomposition.
        Can be either "numpy" or "cupy". If None, the output location is
        the same as the input location.
    use_pinned_memory : bool, optional
        Whether to use pinned memory if cupy is used.
        Default is `True`.

    Returns
    -------
    NDArray
        Matrix Q with orthonormal columns, shape ``(..., m, min(m, n))``.
    NDArray
        Upper triangular matrix R, shape ``(..., min(m, n), n)``.

    Raises
    ------
    ValueError
        If `compute_module` is neither "numpy" nor "cupy", or if a GPU
        computation/location is requested while `xp` is numpy.

    """
    if compute_module not in ("numpy", "cupy"):
        # An unknown module would otherwise skip both compute branches and
        # fail at the return with an UnboundLocalError (after a memcopy).
        raise ValueError(
            f"Invalid compute_module '{compute_module}'. Must be 'numpy' or 'cupy'."
        )

    input_module = get_array_module_name(A)

    if output_module is None:
        output_module = input_module

    if xp.__name__ == "numpy" and (
        compute_module == "cupy" or output_module == "cupy" or input_module == "cupy"
    ):
        raise ValueError("Cannot do gpu computation with numpy as xp.")

    # memcopy to correct location
    A = get_any_location(A, compute_module, use_pinned_memory=use_pinned_memory)

    if compute_module == "cupy":
        q, r = xp.linalg.qr(A)
    else:
        # The numba kernel expects a single batch axis: flatten all leading
        # dimensions, factorize, then restore them on both outputs.
        batch_shape = A.shape[:-2]
        m = A.shape[-2]
        n = A.shape[-1]

        q, r = _qr_numba(A.reshape((-1, m, n)))

        k = min(m, n)
        q = q.reshape((*batch_shape, m, k))
        r = r.reshape((*batch_shape, k, n))

    return (
        get_any_location(q, output_module, use_pinned_memory=use_pinned_memory),
        get_any_location(r, output_module, use_pinned_memory=use_pinned_memory),
    )