# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.
import numba as nb
import numpy as np
from numpy.typing import NDArray
from qttools.profiling import Profiler
# Shared profiler instance; its `profile` decorator is applied to the
# jitted kernels defined in this module.
profiler = Profiler()
@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True, no_rewrites=True)
def find_inds(
    self_rows: NDArray, self_cols: NDArray, rows: NDArray, cols: NDArray
) -> tuple[NDArray, NDArray, int]:
    """Finds the corresponding indices of the given rows and columns.

    Every stored coordinate ``(self_rows[i], self_cols[i])`` is compared
    against all requested coordinates ``(rows[j], cols[j])``. The number
    of matches per stored coordinate is counted as well, which is used
    to check if the requested indices contain duplicates.

    Parameters
    ----------
    self_rows : NDArray
        The rows of this matrix.
    self_cols : NDArray
        The columns of this matrix.
    rows : NDArray
        The rows to find the indices for.
    cols : NDArray
        The columns to find the indices for.

    Returns
    -------
    inds : NDArray
        The positions in `self_rows` / `self_cols` that are matched by
        at least one of the requested coordinates.
    value_inds : NDArray
        For each entry of `inds`, the position in `rows` / `cols` of the
        matching requested coordinate. If several requested coordinates
        match the same stored coordinate, the last one is reported.
    max_counts : int
        The maximum number of matches found for a single stored
        coordinate. Zero if this matrix holds no coordinates.

    """
    full_inds = np.zeros(self_rows.shape[0], dtype=np.int32)
    # NOTE: The match counter is as wide as `full_inds`. Integer
    # arithmetic wraps silently in jitted code, so a narrower counter
    # could wrap around for heavily repeated requests and hide the
    # duplicates this count is meant to expose.
    counts = np.zeros(self_rows.shape[0], dtype=np.int32)

    for i in nb.prange(self_rows.shape[0]):
        for j in range(rows.shape[0]):
            # Branch-free update: `cond` is 1 on a match and 0
            # otherwise, so the stored index is only overwritten by `j`
            # when the coordinates agree.
            cond = int((self_rows[i] == rows[j]) & (self_cols[i] == cols[j]))
            full_inds[i] = full_inds[i] * (1 - cond) + j * cond
            counts[i] += cond

    # Find the valid indices, i.e. the stored coordinates that were
    # matched at least once.
    inds = np.nonzero(counts)[0]
    value_inds = full_inds[inds]

    if counts.size == 0:
        # No data in this block, return an empty slice. The maximum of
        # an empty array is undefined, so report zero matches.
        return inds, value_inds, 0

    return inds, value_inds, np.max(counts)
@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True)
def compute_block_slice(
    rows: NDArray, cols: NDArray, block_offsets: NDArray, row: int, col: int
):
    """Locates the data range that belongs to the requested block.

    Parameters
    ----------
    rows : NDArray
        The rows in the COO matrix.
    cols : NDArray
        The columns in the COO matrix.
    block_offsets : NDArray
        The block offsets.
    row : int
        The block row.
    col : int
        The block column.

    Returns
    -------
    start : int
        The index of the first element that lies in the block, or None
        if no element lies in the block.
    stop : int
        One past the index of the last element that lies in the block,
        or None if no element lies in the block.

    """
    # Coordinate window covered by the requested block.
    row_lo = block_offsets[row]
    row_hi = block_offsets[row + 1]
    col_lo = block_offsets[col]
    col_hi = block_offsets[col + 1]

    num_entries = rows.shape[0]
    in_block = np.zeros(num_entries, dtype=np.bool_)
    for k in nb.prange(num_entries):
        row_hit = (rows[k] >= row_lo) & (rows[k] < row_hi)
        col_hit = (cols[k] >= col_lo) & (cols[k] < col_hi)
        in_block[k] = row_hit & col_hit

    hits = np.nonzero(in_block)[0]
    if hits.shape[0] == 0:
        # No data in this block, return an empty slice.
        return None, None

    # NOTE: The data is sorted by block-row and -column, so the block
    # is assumed to be contiguous and only its first and last element
    # are needed to describe it.
    return hits[0], hits[-1] + 1
@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True)
def densify_block(
    block: NDArray,
    rows: NDArray,
    cols: NDArray,
    data: NDArray,
    block_slice: slice,
    row_offset: int,
    col_offset: int,
):
    """Scatters a contiguous range of sparse data into a dense block.

    Note
    ----
    For very small blocks the cost of spinning up the CPU threads can
    outweigh the benefit, making the jitted version slower than the
    bare API implementation.

    Parameters
    ----------
    block : NDArray
        Preallocated dense block. Should be filled with zeros.
    rows : NDArray
        The row coordinates of the data; the entries selected by
        `block_slice` are used.
    cols : NDArray
        The column coordinates of the data; the entries selected by
        `block_slice` are used.
    data : NDArray
        The data to fill the block with, indexed along its last axis.
    block_slice : slice
        The range of data elements that belong to the block.
    row_offset : int
        The row offset of the block, subtracted from the row
        coordinates.
    col_offset : int
        The column offset of the block, subtracted from the column
        coordinates.

    """
    # NOTE: The block's data is assumed to be one contiguous range;
    # this will not work otherwise.
    first = block_slice.start or 0
    last = block_slice.stop
    for k in nb.prange(first, last):
        # Translate the coordinates into the block's local frame.
        local_row = rows[k] - row_offset
        local_col = cols[k] - col_offset
        block[..., local_row, local_col] = data[..., k]
@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True)
def sparsify_block(block: NDArray, rows: NDArray, cols: NDArray, data: NDArray):
    """Gathers entries of a dense block into the sparse data array.

    Parameters
    ----------
    block : NDArray
        The dense block to sparsify.
    rows : NDArray
        The block rows to read from, one per data element.
    cols : NDArray
        The block columns to read from, one per data element.
    data : NDArray
        The data to be filled with the block. It is written in place
        along its last axis.

    """
    num_entries = rows.shape[0]
    for k in nb.prange(num_entries):
        src_row = rows[k]
        src_col = cols[k]
        data[..., k] = block[..., src_row, src_col]
@profiler.profile(level="api")
@nb.njit(parallel=True, cache=True)
def compute_block_sort_index(
    coo_rows: NDArray, coo_cols: NDArray, block_sizes: NDArray
) -> NDArray:
    """Computes the block-sorting index for a sparse matrix.

    The blocks are visited block-row by block-row and, within each
    block-row, block-column by block-column. The elements of a block
    keep their original relative order.

    Note
    ----
    A scratch buffer with one row per block-column is allocated, so this
    method uses more memory than a naive implementation. No assumptions
    on the sparsity pattern of the matrix are made here.

    Parameters
    ----------
    coo_rows : NDArray
        The row indices of the matrix in coordinate format.
    coo_cols : NDArray
        The column indices of the matrix in coordinate format.
    block_sizes : NDArray
        The block sizes of the block-sparse matrix we want to construct.

    Returns
    -------
    sort_index : NDArray
        The indexing that sorts the data by block-row and -column.

    """
    nnz_total = len(coo_cols)
    num_blocks = block_sizes.shape[0]
    block_offsets = np.hstack((np.array([0]), np.cumsum(block_sizes)))

    sort_index = np.zeros(nnz_total, dtype=np.int32)

    # NOTE: Generous bound on the number of nonzeros a single block can
    # hold: the largest block is at most max(block_sizes) squared, and
    # never more than the total number of elements.
    # NOTE(review): this bound presumably relies on the coordinates
    # being unique -- confirm against the callers.
    scratch_width = min(nnz_total, max(block_sizes) ** 2)
    scratch = np.zeros((num_blocks, scratch_width), dtype=np.int32)
    block_nnz = np.zeros(num_blocks, dtype=np.int32)

    write_pos = 0
    for brow in range(num_blocks):
        # The row mask is shared by all blocks of this block-row.
        in_block_row = (coo_rows >= block_offsets[brow]) & (
            coo_rows < block_offsets[brow + 1]
        )

        # Collect the element indices of every block in this block-row
        # in parallel; each block-column writes to its own scratch row.
        for bcol in nb.prange(num_blocks):
            in_block = (
                in_block_row
                & (coo_cols >= block_offsets[bcol])
                & (coo_cols < block_offsets[bcol + 1])
            )
            found = np.sum(in_block)
            block_nnz[bcol] = found
            if found > 0:
                scratch[bcol, :found] = np.nonzero(in_block)[0]

        # Append the collected indices sequentially, in block-column
        # order, since the write position depends on the previous
        # blocks.
        for bcol in range(num_blocks):
            count = block_nnz[bcol]
            if count <= 0:
                continue
            sort_index[write_pos : write_pos + count] = scratch[bcol, :count]
            write_pos += count

    return sort_index