Skip to content

qttools.kernels.datastructure.cupy.dsdbsparse

[docs] module qttools.kernels.datastructure.cupy.dsdbsparse

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.


import cupy as cp
import numpy as np

from qttools import QTX_USE_CUPY_JIT, NDArray
from qttools.kernels.datastructure.cupy import THREADS_PER_BLOCK
from qttools.profiling import Profiler

if QTX_USE_CUPY_JIT:
    from qttools.kernels.datastructure.cupy import _cupy_jit as cupy_backend
else:
    from qttools.kernels.datastructure.cupy import _cupy_rawkernel as cupy_backend


# Module-level profiler instance; its ``profile`` decorator is used to
# instrument the API-level functions defined in this module (e.g.
# ``find_ranks``).
profiler = Profiler()


@profiler.profile(level="api")
def find_ranks(nnz_section_offsets: NDArray, inds: NDArray) -> NDArray:
    """Finds the ranks of the indices in the offsets.

    The lookup itself is done on the GPU by the backend kernel
    ``_find_ranks`` (raw-kernel or JIT variant, depending on
    ``QTX_USE_CUPY_JIT``); this function only prepares the buffers and
    launches it with one thread per index.

    Parameters
    ----------
    nnz_section_offsets : NDArray
        The offsets of the non-zero sections. Passed to the kernel as a
        C-contiguous ``int32`` array.
    inds : NDArray
        The indices to find the ranks for (one-dimensional; only
        ``inds.shape[0]`` is used). Passed to the kernel as a
        C-contiguous ``int32`` array.

    Returns
    -------
    ranks : NDArray
        The ranks of the indices in the offsets, as an ``int16`` array
        with one entry per element of `inds`. An empty `inds` yields an
        empty ``int16`` array without launching the kernel.

    """
    num_inds = inds.shape[0]
    ranks = cp.zeros(num_inds, dtype=cp.int16)

    # A zero-sized grid is an invalid CUDA launch configuration, so an
    # empty input must short-circuit instead of reaching the kernel.
    if num_inds == 0:
        return ranks

    # The kernel works on raw device buffers, so the inputs must be
    # C-contiguous int32. ``ascontiguousarray`` only copies when the
    # input does not already satisfy this (``astype`` copied every call).
    nnz_section_offsets = cp.ascontiguousarray(nnz_section_offsets, dtype=cp.int32)
    inds = cp.ascontiguousarray(inds, dtype=cp.int32)

    # Ceil-divide so that every index is covered by a thread.
    blocks_per_grid = (num_inds + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    cupy_backend._find_ranks(
        (blocks_per_grid,),
        (THREADS_PER_BLOCK,),
        (
            nnz_section_offsets,
            inds,
            ranks,
            np.int32(nnz_section_offsets.shape[0]),
            np.int32(num_inds),
        ),
    )
    return ranks