Skip to content

qttools.kernels.datastructure.numba.dsdbsparse

[docs] module qttools.kernels.datastructure.numba.dsdbsparse

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import numba as nb
import numpy as np
from numpy.typing import NDArray

from qttools.profiling import Profiler

# Module-level profiler instance; applied below through the
# `@profiler.profile(...)` decorator.
profiler = Profiler()


@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True)
def find_ranks(nnz_section_offsets: NDArray, inds: NDArray) -> NDArray:
    """Locate, for each index, the last offset that does not exceed it.

    Every entry of `inds` is compared against all entries of
    `nnz_section_offsets`. The position of the last offset that is
    less than or equal to the index is taken as its rank. An index
    that is smaller than every offset is assigned rank 0.

    Parameters
    ----------
    nnz_section_offsets : NDArray
        The offsets of the non-zero sections.
    inds : NDArray
        The indices to find the ranks for.

    Returns
    -------
    ranks : NDArray
        The ranks of the indices in the offsets, stored as int16 with
        one entry per index.

    """
    num_inds = inds.shape[0]
    num_sections = nnz_section_offsets.shape[0]
    ranks = np.zeros(num_inds, dtype=np.int16)

    # Each index is processed independently of the others, so the
    # outer loop is distributed with prange.
    for k in nb.prange(num_inds):
        target = inds[k]
        # Full linear scan: a later match overwrites an earlier one, so
        # the last offset <= target wins; no match leaves the rank at 0.
        rank = 0
        for section in range(num_sections):
            if nnz_section_offsets[section] <= target:
                rank = section
        ranks[k] = rank

    return ranks