qttools.kernels.datastructure.cupy.dsdbcsr

module qttools.kernels.datastructure.cupy.dsdbcsr

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import cupy as cp
import numpy as np

from qttools import QTX_USE_CUPY_JIT, NDArray
from qttools.kernels.datastructure.cupy import THREADS_PER_BLOCK
from qttools.profiling import Profiler

if QTX_USE_CUPY_JIT:
    from qttools.kernels.datastructure.cupy import _cupy_jit as cupy_backend
else:
    from qttools.kernels.datastructure.cupy import _cupy_rawkernel as cupy_backend
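# Both backends expose the kernels used below with the same launch convention,
#   kernel((blocks_per_grid,), (threads_per_block,), (arg_0, ..., arg_n));
# only the mask reduction in compute_rowptr_map is dispatched per backend.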


profiler = Profiler()


@profiler.profile(level="api")
def find_inds(
    rowptr_map: dict,
    block_offsets: NDArray,
    self_cols: NDArray,
    rows: NDArray,
    cols: NDArray,
) -> tuple[NDArray, NDArray]:
    """Finds the corresponding indices of the given rows and columns.

    Parameters
    ----------
    rowptr_map : dict
        The row pointer map.
    block_offsets : NDArray
        The block offsets.
    self_cols : NDArray
        The columns of this matrix.
    rows : NDArray
        The rows to find the indices for.
    cols : NDArray
        The columns to find the indices for.

    Returns
    -------
    inds : NDArray
        The indices into this matrix's data where the given rows and
        columns were found.
    value_inds : NDArray
        The corresponding indices into the given rows and cols arrays.

    """
    brows = cp.zeros(rows.shape[0], dtype=cp.int32)
    bcols = cp.zeros(cols.shape[0], dtype=cp.int32)
    block_offsets = block_offsets.astype(cp.int32)
    rows = rows.astype(cp.int32)
    cols = cols.astype(cp.int32)
    self_cols = self_cols.astype(cp.int32)

    bcoords_blocks_per_grid = (
        rows.shape[0] + THREADS_PER_BLOCK - 1
    ) // THREADS_PER_BLOCK
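    # Ceil division: launch enough blocks for one thread per (row, col) query.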

    cupy_backend._find_bcoords(
        (bcoords_blocks_per_grid,),
        (THREADS_PER_BLOCK,),
        (
            block_offsets,
            rows,
            cols,
            brows,
            bcols,
            np.int32(rows.shape[0]),
            np.int32(block_offsets.shape[0]),
        ),
    )
    # Get an ordered list of unique blocks.
    unique_blocks = dict.fromkeys(zip(map(int, brows), map(int, bcols))).keys()
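    # dict.fromkeys deduplicates while preserving first-seen order; note that
    # map(int, ...) pulls the block coordinates to the host.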

    block_mask_blocks_per_grid = (
        brows.shape[0] + THREADS_PER_BLOCK - 1
    ) // THREADS_PER_BLOCK

    inds, value_inds = [], []
    for brow, bcol in unique_blocks:
        rowptr = rowptr_map.get((brow, bcol), None)
        if rowptr is None:
            continue
        mask = cp.zeros(brows.shape[0], dtype=cp.bool_)
        cupy_backend._compute_block_mask(
            (block_mask_blocks_per_grid,),
            (THREADS_PER_BLOCK,),
            (
                brows,
                bcols,
                brow,
                bcol,
                mask,
                np.int32(brows.shape[0]),
            ),
        )
        mask_inds = cp.nonzero(mask)[0]

        # Renormalize the row indices for this block.
        rr = rows[mask] - block_offsets[brow]
        cc = cols[mask]

        rowptr = rowptr.astype(cp.int32)

        block_inds = cp.zeros(rr.shape[0], dtype=cp.int32)
        blocks_per_grid = (rr.shape[0] + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
        cupy_backend._compute_block_inds(
            (blocks_per_grid,),
            (THREADS_PER_BLOCK,),
            (
                rr,
                cc,
                self_cols,
                rowptr,
                block_inds,
                np.int32(rr.shape[0]),
            ),
        )

        valid = block_inds != -1

        inds.extend(block_inds[valid])
        value_inds.extend(mask_inds[valid])

    return cp.array(inds, dtype=int), cp.array(value_inds, dtype=int)
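

# Illustrative usage sketch for find_inds (not part of the module; the toy COO
# matrix, block sizes, and query coordinates below are assumptions chosen only
# to show how the pieces fit together):
#
#     import cupy as cp
#     import numpy as np
#     from qttools.kernels.datastructure.cupy import dsdbcsr
#
#     coo_rows = cp.array([0, 1, 2, 3])
#     coo_cols = cp.array([1, 0, 3, 2])
#     block_sizes = np.array([2, 2])
#     block_offsets = cp.asarray(np.hstack(([0], np.cumsum(block_sizes))))
#
#     sort_index, rowptr_map = dsdbcsr.compute_rowptr_map(
#         coo_rows, coo_cols, block_sizes
#     )
#     self_cols = coo_cols[sort_index]  # columns in blockwise CSR order
#
#     # Where do the entries (0, 1) and (2, 3) sit in the block-sorted data?
#     inds, value_inds = dsdbcsr.find_inds(
#         rowptr_map, block_offsets, self_cols, cp.array([0, 2]), cp.array([1, 3])
#     )
#     # For this layout, inds should be [0, 2] and value_inds should be [0, 1].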


@profiler.profile(level="api")
def densify_block(
    block: NDArray,
    block_offset: NDArray,
    self_cols: NDArray,
    rowptr: NDArray,
    data: NDArray,
):
    """Fills the dense block with the given data.

    Parameters
    ----------
    block : NDArray
        Preallocated dense block. Should be filled with zeros.
    block_offset : NDArray
        The block offset.
    self_cols : NDArray
        The column indices of this matrix.
    rowptr : NDArray
        The row pointer of this matrix block.
    data : NDArray
        The data to fill the block with.

    """
    cols = self_cols[rowptr[0] : rowptr[-1]] - block_offset
    rows = cp.zeros(cols.shape[0], dtype=cp.int32)
    blocks_per_grid = (rowptr.shape[0] + THREADS_PER_BLOCK - 2) // THREADS_PER_BLOCK

    rowptr = rowptr.astype(cp.int32)

    cupy_backend._expand_rows(
        (blocks_per_grid,),
        (THREADS_PER_BLOCK,),
        (
            rows,
            rowptr - rowptr[0],
            np.int32(rowptr.shape[0]),
        ),
    )
    block[..., rows, cols] = data[..., rowptr[0] : rowptr[-1]]
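

# Illustrative usage sketch for densify_block (not part of the module; the toy
# setup from the find_inds sketch above and the single leading stack dimension
# are assumptions):
#
#     data = cp.arange(4, dtype=cp.float64)[cp.newaxis, :]  # shape (1, nnz)
#     block = cp.zeros((1, 2, 2), dtype=cp.float64)
#     dsdbcsr.densify_block(
#         block,
#         block_offset=block_offsets[0],
#         self_cols=self_cols,
#         rowptr=rowptr_map[(0, 0)],
#         data=data,
#     )
#     # block[0] is now the dense 2 x 2 block (0, 0), with data[0, 0:2]
#     # scattered to positions (0, 1) and (1, 0).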


@profiler.profile(level="api")
def sparsify_block(
    block: NDArray,
    block_offset: NDArray,
    self_cols: NDArray,
    rowptr: NDArray,
    data: NDArray,
):
    """Fills the data with the given dense block.

    Parameters
    ----------
    block : NDArray
        The dense block to sparsify.
    block_offset : NDArray
        The block offset.
    self_cols : NDArray
        The column indices of this matrix.
    rowptr : NDArray
        The row pointer of this matrix block.
    data : NDArray
        The data array to be filled from the dense block.

    """
    cols = self_cols[rowptr[0] : rowptr[-1]] - block_offset
    rows = cp.zeros(cols.shape[0], dtype=cp.int32)
    blocks_per_grid = (rowptr.shape[0] + THREADS_PER_BLOCK - 2) // THREADS_PER_BLOCK

    rowptr = rowptr.astype(cp.int32)

    cupy_backend._expand_rows(
        (blocks_per_grid,),
        (THREADS_PER_BLOCK,),
        (
            rows,
            rowptr - rowptr[0],
            np.int32(rowptr.shape[0]),
        ),
    )
    data[..., rowptr[0] : rowptr[-1]] = block[..., rows, cols]
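

# Illustrative usage sketch for sparsify_block (not part of the module; it
# simply reverses the densification from the sketch above under the same
# assumed toy setup):
#
#     out = cp.zeros_like(data)
#     dsdbcsr.sparsify_block(
#         block,
#         block_offset=block_offsets[0],
#         self_cols=self_cols,
#         rowptr=rowptr_map[(0, 0)],
#         data=out,
#     )
#     # out[..., 0:2] now matches data[..., 0:2].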


@profiler.profile(level="api")
def compute_rowptr_map(
    coo_rows: NDArray, coo_cols: NDArray, block_sizes: NDArray
) -> tuple[NDArray, dict]:
    """Computes the block-sorting index and the rowptr map.

    Note
    ----
    This is a combination of the bare block-sorting index computation
    and the rowptr map computation.

    Parameters
    ----------
    coo_rows : NDArray
        The row indices of the matrix in coordinate format.
    coo_cols : NDArray
        The column indices of the matrix in coordinate format.
    block_sizes : NDArray
        The block sizes of the block-sparse matrix we want to construct.

    Returns
    -------
    sort_index : NDArray
        The block-sorting index for the sparse matrix.
    rowptr_map : dict
        The row pointer map, describing the block-sparse matrix in a
        blockwise compressed sparse row (CSR) format.

    """
    num_blocks = block_sizes.shape[0]
    block_offsets = np.hstack((np.array([0]), np.cumsum(block_sizes)), dtype=np.int32)
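    # e.g. block_sizes [2, 3, 1] gives block_offsets [0, 2, 5, 6].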

    sort_index = cp.zeros(len(coo_cols), dtype=cp.int32)
    rowptr_map = {}
    mask = cp.zeros(len(coo_cols), dtype=cp.int32)

    coo_rows = coo_rows.astype(cp.int32)
    coo_cols = coo_cols.astype(cp.int32)

    blocks_per_grid = (len(coo_cols) + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
    offset = 0
    for i, j in cp.ndindex(num_blocks, num_blocks):
        cupy_backend._compute_coo_block_mask(
            (blocks_per_grid,),
            (THREADS_PER_BLOCK,),
            (
                coo_rows,
                coo_cols,
                np.int32(block_offsets[i]),
                np.int32(block_offsets[i + 1]),
                np.int32(block_offsets[j]),
                np.int32(block_offsets[j + 1]),
                mask,
                np.int32(len(coo_rows)),
            ),
        )

        if QTX_USE_CUPY_JIT:
            bnnz = cp.sum(mask)
        else:
            bnnz = cupy_backend.reduction(mask)

        if bnnz != 0:
            # Sort the data by block-row and -column.
            sort_index[offset : offset + bnnz] = cp.nonzero(mask)[0]

            # Compute the rowptr map.
            hist, __ = cp.histogram(
                coo_rows[mask.astype(cp.bool_)] - block_offsets[i],
                bins=cp.arange(block_sizes[i] + 1),
            )
            rowptr = cp.hstack((cp.array([0]), cp.cumsum(hist))) + offset
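            # e.g. per-row counts [2, 0, 1] at offset 4 give rowptr [4, 6, 6, 7].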
            rowptr_map[(i, j)] = rowptr

            offset += bnnz

    return sort_index, rowptr_map
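

# Illustrative result sketch for compute_rowptr_map (not part of the module).
# For the assumed 4 x 4 toy matrix from the sketches above (entries at (0, 1),
# (1, 0), (2, 3), (3, 2) with block_sizes [2, 2]), the block-sorting index
# should come out as [0, 1, 2, 3] (the COO input is already block-ordered) and
# the rowptr map as
#
#     {(0, 0): [0, 1, 2], (1, 1): [2, 3, 4]}
#
# i.e. only the two diagonal blocks appear, and each block's rowptr indexes
# directly into the block-sorted data.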