Skip to content

qttools.kernels.linalg.qr

[docs] module qttools.kernels.linalg.qr

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

import numba as nb
import numpy as np

from qttools import NDArray, xp
from qttools.utils.gpu_utils import get_any_location, get_array_module_name


@nb.njit(parallel=True, cache=True, no_rewrites=True)
def _qr_numba(
    A: NDArray,
) -> tuple[NDArray, NDArray]:
    """Batched reduced QR decomposition, parallelized over the batch.

    Every matrix ``A[i]`` along the leading axis is factorized
    independently with ``np.linalg.qr`` inside a numba ``prange`` loop.

    Parameters
    ----------
    A : NDArray
        Stack of matrices; the leading axis is the batch axis and the
        last two axes hold the matrices to factorize.

    Returns
    -------
    NDArray
        The Q factors, of shape ``(batch, m, min(m, n))``.
    NDArray
        The upper triangular R factors, of shape
        ``(batch, min(m, n), n)``.

    """

    n_batch = A.shape[0]
    n_rows = A.shape[-2]
    n_cols = A.shape[-1]

    # Inner dimension shared by the Q and R factors.
    n_inner = min(n_rows, n_cols)

    # Output buffers are allocated once; each iteration fills one slot.
    r_all = np.empty((n_batch, n_inner, n_cols), dtype=A.dtype)
    q_all = np.empty((n_batch, n_rows, n_inner), dtype=A.dtype)

    for idx in nb.prange(n_batch):
        q_idx, r_idx = np.linalg.qr(A[idx])
        q_all[idx] = q_idx
        r_all[idx] = r_idx

    return q_all, r_all


def qr(
    A: NDArray,
    compute_module: str = "numpy",
    output_module: str | None = None,
    use_pinned_memory: bool = True,
) -> tuple[NDArray, NDArray]:
    """Computes the QR decomposition of a batch of matrices.

    If compute_module is "numpy", the computation is done with numpy and
    parallelized with numba. Only mode 'reduced' is supported due to
    numba limitations. If compute_module is "cupy", the computation is
    delegated to ``xp.linalg.qr``.

    Parameters
    ----------
    A : NDArray
        The matrices. The last two axes hold the matrices, all leading
        axes are treated as batch axes.
    compute_module : str, optional
        The location where to compute the QR decomposition.
        Can be either "numpy" or "cupy". Default is "numpy".
    output_module : str, optional
        The location where to store the QR decomposition.
        Can be either "numpy" or "cupy". If None, the output location
        is the same as the input location.
    use_pinned_memory : bool, optional
        Whether to use pinned memory if cupy is used.
        Default is `True`.

    Returns
    -------
    NDArray
        Unitary matrix Q in the QR decomposition, of shape
        ``(..., m, min(m, n))``.
    NDArray
        Upper triangular matrix R in the QR decomposition, of shape
        ``(..., min(m, n), n)``.

    Raises
    ------
    ValueError
        If `compute_module` is neither "numpy" nor "cupy", or if a cupy
        location is requested while `xp` is numpy.

    """

    # Fail early with a clear message; otherwise an unknown module would
    # fall through both branches below and leave q and r undefined.
    if compute_module not in ("numpy", "cupy"):
        raise ValueError(
            f"Unsupported compute_module '{compute_module}', "
            "expected 'numpy' or 'cupy'."
        )

    input_module = get_array_module_name(A)

    if output_module is None:
        output_module = input_module

    if xp.__name__ == "numpy" and (
        compute_module == "cupy" or output_module == "cupy" or input_module == "cupy"
    ):
        raise ValueError("Cannot do gpu computation with numpy as xp.")

    # memcopy to correct location
    A = get_any_location(A, compute_module, use_pinned_memory=use_pinned_memory)

    if compute_module == "cupy":
        q, r = xp.linalg.qr(A)
    else:
        # The numba kernel expects a single batch axis, so flatten all
        # leading axes and restore them on the results afterwards.
        batch_shape = A.shape[:-2]
        m = A.shape[-2]
        n = A.shape[-1]
        A = A.reshape((-1, m, n))

        q, r = _qr_numba(A)

        k = min(m, n)
        q = q.reshape((*batch_shape, m, k))
        r = r.reshape((*batch_shape, k, n))

    return get_any_location(
        q, output_module, use_pinned_memory=use_pinned_memory
    ), get_any_location(r, output_module, use_pinned_memory=use_pinned_memory)