# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.
import cupy as cp
import numpy as np
from qttools import QTX_USE_CUPY_JIT, NDArray
from qttools.kernels.datastructure.cupy import THREADS_PER_BLOCK
from qttools.profiling import Profiler
if QTX_USE_CUPY_JIT:
from qttools.kernels.datastructure.cupy import _cupy_jit as cupy_backend
else:
from qttools.kernels.datastructure.cupy import _cupy_rawkernel as cupy_backend
profiler = Profiler()
@profiler.profile(level="api")
def find_inds(
rowptr_map: dict,
block_offsets: NDArray,
self_cols: NDArray,
rows: NDArray,
cols: NDArray,
) -> tuple[NDArray, NDArray]:
"""Finds the corresponding indices of the given rows and columns.
Parameters
----------
rowptr_map : dict
The row pointer map.
block_offsets : NDArray
The block offsets.
    self_cols : NDArray
        The column indices of this matrix.
rows : NDArray
The rows to find the indices for.
cols : NDArray
The columns to find the indices for.
Returns
-------
    inds : NDArray
        The indices into this matrix's data at which the given elements
        were found.
    value_inds : NDArray
        The indices into the given rows and columns that have a match in
        this matrix.
"""
brows = cp.zeros(rows.shape[0], dtype=cp.int32)
bcols = cp.zeros(cols.shape[0], dtype=cp.int32)
block_offsets = block_offsets.astype(cp.int32)
rows = rows.astype(cp.int32)
cols = cols.astype(cp.int32)
self_cols = self_cols.astype(cp.int32)
bcoords_blocks_per_grid = (
rows.shape[0] + THREADS_PER_BLOCK - 1
) // THREADS_PER_BLOCK
cupy_backend._find_bcoords(
(bcoords_blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
block_offsets,
rows,
cols,
brows,
bcols,
np.int32(rows.shape[0]),
np.int32(block_offsets.shape[0]),
),
)
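    # brows and bcols now hold the block-row and block-column index of
    # every queried element.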
    # Get an ordered list of unique blocks (single host transfer).
    unique_blocks = dict.fromkeys(zip(brows.get().tolist(), bcols.get().tolist())).keys()
block_mask_blocks_per_grid = (
brows.shape[0] + THREADS_PER_BLOCK - 1
) // THREADS_PER_BLOCK
inds, value_inds = [], []
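    # Resolve indices block by block; blocks without an entry in the
    # rowptr map contain no data and are skipped.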
for brow, bcol in unique_blocks:
rowptr = rowptr_map.get((brow, bcol), None)
if rowptr is None:
continue
mask = cp.zeros(brows.shape[0], dtype=cp.bool_)
cupy_backend._compute_block_mask(
(block_mask_blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
brows,
bcols,
brow,
bcol,
mask,
np.int32(brows.shape[0]),
),
)
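        # mask flags the queried elements that fall into block (brow, bcol).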
mask_inds = cp.nonzero(mask)[0]
# Renormalize the row indices for this block.
rr = rows[mask] - block_offsets[brow]
cc = cols[mask]
rowptr = rowptr.astype(cp.int32)
block_inds = cp.zeros(rr.shape[0], dtype=cp.int32)
blocks_per_grid = (rr.shape[0] + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
cupy_backend._compute_block_inds(
(blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
rr,
cc,
self_cols,
rowptr,
block_inds,
np.int32(rr.shape[0]),
),
)
        # The kernel marks unmatched (row, col) pairs with -1; drop them.
        valid = block_inds != -1
        inds.append(block_inds[valid])
        value_inds.append(mask_inds[valid])
    if not inds:
        return cp.empty(0, dtype=int), cp.empty(0, dtype=int)
    return cp.concatenate(inds).astype(int), cp.concatenate(value_inds).astype(int)
@profiler.profile(level="api")
def densify_block(
block: NDArray,
block_offset: NDArray,
self_cols: NDArray,
rowptr: NDArray,
data: NDArray,
):
"""Fills the dense block with the given data.
Parameters
----------
block : NDArray
Preallocated dense block. Should be filled with zeros.
block_offset : NDArray
The block offset.
self_cols : NDArray
The column indices of this matrix.
rowptr : NDArray
The row pointer of this matrix block.
data : NDArray
The data to fill the block with.
"""
cols = self_cols[rowptr[0] : rowptr[-1]] - block_offset
rows = cp.zeros(cols.shape[0], dtype=cp.int32)
blocks_per_grid = (rowptr.shape[0] + THREADS_PER_BLOCK - 2) // THREADS_PER_BLOCK
rowptr = rowptr.astype(cp.int32)
cupy_backend._expand_rows(
(blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
rows,
rowptr - rowptr[0],
np.int32(rowptr.shape[0]),
),
)
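    # The ellipsis broadcasts the scatter over any leading (stack)
    # dimensions of the data.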
block[..., rows, cols] = data[..., rowptr[0] : rowptr[-1]]
@profiler.profile(level="api")
def sparsify_block(
block: NDArray,
block_offset: NDArray,
self_cols: NDArray,
rowptr: NDArray,
data: NDArray,
):
"""Fills the data with the given dense block.
Parameters
----------
block : NDArray
The dense block to sparsify.
block_offset : NDArray
The block offset.
self_cols : NDArray
The column indices of this matrix.
rowptr : NDArray
The row pointer of this matrix block.
data : NDArray
The data to be filled with the block.
"""
cols = self_cols[rowptr[0] : rowptr[-1]] - block_offset
rows = cp.zeros(cols.shape[0], dtype=cp.int32)
    blocks_per_grid = (rowptr.shape[0] + THREADS_PER_BLOCK - 2) // THREADS_PER_BLOCK
rowptr = rowptr.astype(cp.int32)
cupy_backend._expand_rows(
(blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
rows,
rowptr - rowptr[0],
np.int32(rowptr.shape[0]),
),
)
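    # The ellipsis broadcasts the gather over any leading (stack)
    # dimensions of the data.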
data[..., rowptr[0] : rowptr[-1]] = block[..., rows, cols]
@profiler.profile(level="api")
def compute_rowptr_map(
    coo_rows: NDArray, coo_cols: NDArray, block_sizes: NDArray
) -> tuple[NDArray, dict]:
"""Computes the block-sorting index and the rowptr map.
Note
----
This is a combination of the bare block-sorting index computation
and the rowptr map computation.
Parameters
----------
coo_rows : NDArray
The row indices of the matrix in coordinate format.
coo_cols : NDArray
The column indices of the matrix in coordinate format.
block_sizes : NDArray
The block sizes of the block-sparse matrix we want to construct.
Returns
-------
sort_index : NDArray
The block-sorting index for the sparse matrix.
    rowptr_map : dict
        The row pointer map, describing the block-sparse matrix in
        blockwise compressed-sparse-row format.
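    Examples
    --------
    A minimal sketch with illustrative values (assumes a CUDA device and
    the compiled CuPy kernels are available)::

        coo_rows = cp.array([0, 1, 3, 3])
        coo_cols = cp.array([0, 2, 1, 3])
        block_sizes = np.array([2, 2])
        sort_index, rowptr_map = compute_rowptr_map(
            coo_rows, coo_cols, block_sizes
        )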
"""
num_blocks = block_sizes.shape[0]
block_offsets = np.hstack((np.array([0]), np.cumsum(block_sizes)), dtype=np.int32)
sort_index = cp.zeros(len(coo_cols), dtype=cp.int32)
rowptr_map = {}
mask = cp.zeros(len(coo_cols), dtype=cp.int32)
coo_rows = coo_rows.astype(cp.int32)
coo_cols = coo_cols.astype(cp.int32)
blocks_per_grid = (len(coo_cols) + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
offset = 0
    for i, j in np.ndindex(num_blocks, num_blocks):
cupy_backend._compute_coo_block_mask(
(blocks_per_grid,),
(THREADS_PER_BLOCK,),
(
coo_rows,
coo_cols,
np.int32(block_offsets[i]),
np.int32(block_offsets[i + 1]),
np.int32(block_offsets[j]),
np.int32(block_offsets[j + 1]),
mask,
np.int32(len(coo_rows)),
),
)
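        # mask now flags the COO entries that fall within block (i, j).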
if QTX_USE_CUPY_JIT:
bnnz = cp.sum(mask)
else:
bnnz = cupy_backend.reduction(mask)
if bnnz != 0:
# Sort the data by block-row and -column.
sort_index[offset : offset + bnnz] = cp.nonzero(mask)[0]
# Compute the rowptr map.
hist, __ = cp.histogram(
coo_rows[mask.astype(cp.bool_)] - block_offsets[i],
bins=cp.arange(block_sizes[i] + 1),
)
rowptr = cp.hstack((cp.array([0]), cp.cumsum(hist))) + offset
rowptr_map[(i, j)] = rowptr
offset += bnnz
return sort_index, rowptr_map