Skip to content

qttools.utils.sparse_utils#

[docs] module qttools.utils.sparse_utils

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import functools

from qttools import NDArray, sparse, xp
from qttools.datastructures.dsdbsparse import DSDBSparse
from qttools.datastructures.routines import BlockMatrix, bd_matmul_distr
from qttools.profiling import Profiler

# Module-level profiler; its `profile` decorator is applied to functions in this module.
profiler = Profiler()


@profiler.profile(level="debug")
def product_sparsity_pattern(
    *matrices: sparse.spmatrix,
) -> tuple[NDArray, NDArray]:
    """Computes the sparsity pattern of the product of a sequence of matrices.

    The pattern is computed structurally: the stored values of every
    factor are replaced by ones before multiplying, so the result
    depends only on *where* entries are stored, not on their values.
    Multiplying the actual values instead can lose pattern entries
    whenever a product entry evaluates to zero, e.g. for purely
    imaginary entries reduced to their real part or for terms of
    opposite sign that cancel.

    Explicitly stored zeros in the inputs count as part of the pattern.
    The input matrices are not modified.

    Parameters
    ----------
    matrices : sparse.spmatrix
        A sequence of sparse matrices, multiplied from left to right.
        At least one matrix is required.

    Returns
    -------
    rows : NDArray
        The row indices of the sparsity pattern.
    cols : NDArray
        The column indices of the sparsity pattern.

    """
    # NOTE: cupyx.scipy.sparse does not support bool dtype in matmul,
    # hence the float32 ones instead of a boolean pattern.
    patterns = []
    for matrix in matrices:
        csr = matrix.tocsr()
        # Build a new matrix on top of the existing index arrays so that
        # the caller's data is left untouched.
        ones = xp.ones_like(csr.data, dtype=xp.float32)
        patterns.append(
            sparse.csr_matrix((ones, csr.indices, csr.indptr), shape=csr.shape)
        )

    # All values are positive, so no product entry can cancel to zero.
    product = functools.reduce(lambda x, y: x @ y, patterns)
    product = product.tocoo()
    # Canonicalize
    product.sum_duplicates()

    return product.row, product.col


def tocsr_dict(matrix: DSDBSparse) -> dict[tuple[int, int], sparse.csr_matrix]:
    """Builds a CSR matrix for every block of a DSDBSparse matrix.

    The stored values are replaced by float32 ones, so the returned
    blocks carry only the sparsity structure of the input.

    Parameters
    ----------
    matrix : DSDBSparse
        The matrix whose blocks are converted.

    Returns
    -------
    blocks : dict[tuple[int, int], sparse.csr_matrix]
        One CSR matrix per block index pair ``(i, j)``, for all
        ``i`` and ``j`` in ``range(matrix.num_blocks)``.

    """
    block_count = matrix.num_blocks
    sizes = matrix.block_sizes

    def pattern_block(row: int, col: int) -> sparse.csr_matrix:
        raw = matrix.sparse_blocks[row, col]
        # Only the shape of the value array is used here; presumably the
        # index -1 picks a single item out of a stack -- TODO confirm.
        ones = xp.ones_like(raw[0][-1], dtype=xp.float32)
        return sparse.csr_matrix((ones, *raw[1:]), shape=(sizes[row], sizes[col]))

    return {
        (row, col): pattern_block(row, col)
        for row in range(block_count)
        for col in range(block_count)
    }


def _band_block_keys(
    start_block: int, end_block: int, num_blocks: int, num_diag: int
) -> set[tuple[int, int]]:
    """Collects the block indices of a band that touch the given block range."""
    half_band = num_diag // 2
    keys = set()
    # Blocks in the rows of the range; columns left of the range are excluded.
    for i in range(start_block, end_block):
        for j in range(
            max(start_block, i - half_band),
            min(num_blocks, i + half_band + 1),
        ):
            keys.add((i, j))
    # Blocks in the columns of the range that lie below its rows.
    for j in range(start_block, end_block):
        for i in range(
            max(end_block, j - half_band),
            min(num_blocks, j + half_band + 1),
        ):
            keys.add((i, j))
    return keys


def product_sparsity_pattern_dsdbsparse(
    *matrices: DSDBSparse,
    in_num_diag: int = 3,
    out_num_diag: int = None,
    start_block: int = 0,
    end_block: int = None,
    spillover: bool = False,
) -> tuple[NDArray, NDArray]:
    """Computes the sparsity pattern of the product of a sequence of DSDBSparse matrices.

    The matrices are multiplied block-wise from left to right and the
    non-zero entries of the resulting blocks are collected.

    NOTE: Enough boundary layers need to be periodic for this to be
    correct.

    Parameters
    ----------
    matrices : DSDBSparse
        A sequence of at least two DSDBSparse matrices. All matrices are
        assumed to have the same number of blocks, the same block sizes
        and the same block diagonals.
    in_num_diag : int, optional
        The number of block diagonals of each input matrix.
    out_num_diag : int, optional
        The number of block diagonals of the final product. Defaults to
        ``len(matrices) * (in_num_diag - 1) + 1``.
    start_block : int, optional
        The first block of the block range that is processed.
    end_block : int, optional
        One past the last block of the block range that is processed.
        Defaults to the number of blocks.
    spillover : bool, optional
        Whether to add the boundary contributions to the first and last
        blocks of each intermediate product. This only works if the
        right-hand factor is block-tridiagonal.

    Returns
    -------
    rows : NDArray
        The row indices of the sparsity pattern.
    cols : NDArray
        The column indices of the sparsity pattern.

    """

    assert len(matrices) > 1

    a = matrices[0]

    # Assuming that all matrices have the same number of blocks, same block sizes, and same block diagonals.
    num_blocks = a.num_blocks
    end_block = end_block or num_blocks
    block_offsets = a.block_offsets
    a_num_diag = in_num_diag

    in_keys = _band_block_keys(start_block, end_block, num_blocks, in_num_diag)
    a_ = BlockMatrix(a, in_keys, (start_block, start_block))

    out_num_diag = out_num_diag or in_num_diag * len(matrices) - len(matrices) + 1

    for n, b in enumerate(matrices[1:]):

        b_ = BlockMatrix(b, in_keys, (start_block, start_block))
        b_num_diag = in_num_diag
        tmp_num_diag = a_num_diag + b_num_diag - 1
        if n == len(matrices) - 2:
            # The last product is truncated to the requested output band.
            tmp_num_diag = out_num_diag

        # Multiply the two block views; `b_` uses the same keys and
        # origin as the initial `a_`.
        c_ = bd_matmul_distr(
            a=a_,
            b=b_,
            out=None,
            a_num_diag=a_num_diag,
            b_num_diag=b_num_diag,
            out_num_diag=tmp_num_diag,
            start_block=start_block,
            end_block=end_block,
            spillover_correction=False,
        )

        if spillover:
            # NOTE: The following only works if b is BTD.
            if start_block == 0:
                # Left spillover
                for i in range(a_num_diag // 2):
                    c_[i, 0] += a_[i + 1, 0] @ b_[0, 1]

            if end_block == num_blocks:
                # Right spillover
                for i in range(a_num_diag // 2):
                    c_[num_blocks - i - 1, num_blocks - 1] += (
                        a_[num_blocks - i - 2, num_blocks - 1]
                        @ b_[num_blocks - 1, num_blocks - 2]
                    )

        # The product becomes the left factor of the next multiplication.
        a_ = c_
        a_num_diag = tmp_num_diag

    out_keys = _band_block_keys(start_block, end_block, num_blocks, out_num_diag)

    # Collect per-block indices and concatenate once at the end instead
    # of re-copying the growing result for every block.
    row_chunks = [xp.empty(0, dtype=xp.int32)]
    col_chunks = [xp.empty(0, dtype=xp.int32)]

    for i in range(num_blocks):
        for j in range(num_blocks):
            if (i, j) not in out_keys:
                continue
            c_block = c_[i, j]
            block_rows, block_cols = c_block.shape[-2:]
            # Only the first `block_rows * block_cols` values of the
            # block are inspected.
            first = c_block.flat[: block_rows * block_cols]
            coo = sparse.coo_matrix(first.reshape(block_rows, block_cols))
            coo.sum_duplicates()
            coo.eliminate_zeros()
            # Shift the block-local indices by the block offsets.
            row_chunks.append(coo.row + block_offsets[i])
            col_chunks.append(coo.col + block_offsets[j])

    return xp.concatenate(row_chunks), xp.concatenate(col_chunks)