# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.
import numba as nb
import numpy as np
from qttools import NDArray, xp
from qttools.utils.gpu_utils import get_any_location, get_array_module_name
@nb.njit(parallel=True, cache=True, no_rewrites=True)
def _qr_numba(
    A: NDArray,
) -> tuple[NDArray, NDArray]:
    """Factorizes every matrix of a stacked batch as Q @ R.

    The batch is traversed with a numba parallel loop; each matrix is
    handed to `np.linalg.qr` individually.

    Parameters
    ----------
    A : NDArray
        Stack of matrices. The first axis is treated as the batch axis
        and the last two axes as the matrix dimensions (m, n).

    Returns
    -------
    NDArray
        The Q factors, stored with shape (batch, m, min(m, n)).
    NDArray
        The upper triangular R factors, stored with shape
        (batch, min(m, n), n).

    """
    dims = A.shape
    num_matrices = dims[0]
    rows = dims[-2]
    cols = dims[-1]
    inner = min(rows, cols)

    # Output buffers are allocated once up front so that each parallel
    # iteration only writes to its own slice.
    q_out = np.empty((num_matrices, rows, inner), dtype=A.dtype)
    r_out = np.empty((num_matrices, inner, cols), dtype=A.dtype)

    for idx in nb.prange(num_matrices):
        factors = np.linalg.qr(A[idx])
        q_out[idx] = factors[0]
        r_out[idx] = factors[1]

    return q_out, r_out
def qr(
    A: NDArray,
    compute_module: str = "numpy",
    output_module: str | None = None,
    use_pinned_memory: bool = True,
) -> tuple[NDArray, NDArray]:
    """Computes the QR decomposition of a batch of matrices.

    If compute_module is "numpy", the computation is done with numpy and
    parallelized with numba. Only mode 'reduced' is supported due to
    numba limitations.

    Parameters
    ----------
    A : NDArray
        The matrices. The last two axes are the matrix dimensions; any
        leading axes are treated as batch dimensions.
    compute_module : str, optional
        The location where to compute the QR decomposition. Can be
        either "numpy" or "cupy". Default is "numpy".
    output_module : str, optional
        The location where to store the QR decomposition. Can be either
        "numpy" or "cupy". If None, the output location is the same as
        the input location.
    use_pinned_memory : bool, optional
        Whether to use pinned memory if cupy is used. Default is `True`.

    Returns
    -------
    NDArray
        Unitary matrix Q in the QR decomposition.
    NDArray
        Upper triangular matrix R in the QR decomposition.

    Raises
    ------
    ValueError
        If a cupy location is requested (or the input lives on the GPU)
        while `xp` is numpy, or if `compute_module` is neither "numpy"
        nor "cupy".

    """
    input_module = get_array_module_name(A)
    if output_module is None:
        output_module = input_module

    if xp.__name__ == "numpy" and (
        compute_module == "cupy" or output_module == "cupy" or input_module == "cupy"
    ):
        raise ValueError("Cannot do gpu computation with numpy as xp.")

    # Reject unknown compute locations before any data is moved. Without
    # this check neither branch below would run and the function would
    # fail with an UnboundLocalError on `q` / `r` at the return.
    if compute_module not in ("numpy", "cupy"):
        raise ValueError(
            f"Unknown compute_module {compute_module!r}; "
            "expected 'numpy' or 'cupy'."
        )

    # memcopy to correct location
    A = get_any_location(A, compute_module, use_pinned_memory=use_pinned_memory)

    if compute_module == "cupy":
        q, r = xp.linalg.qr(A)
    else:
        # The numba kernel expects exactly one batch axis, so all leading
        # axes are flattened for the call and restored afterwards.
        batch_shape = A.shape[:-2]
        m = A.shape[-2]
        n = A.shape[-1]
        A = A.reshape((-1, m, n))
        q, r = _qr_numba(A)
        k = min(m, n)
        q = q.reshape((*batch_shape, m, k))
        r = r.reshape((*batch_shape, k, n))

    return get_any_location(
        q, output_module, use_pinned_memory=use_pinned_memory
    ), get_any_location(r, output_module, use_pinned_memory=use_pinned_memory)