Skip to content

qttools.utils.inplace_utils#

[docs] module qttools.utils.inplace_utils

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

import numpy as np

from qttools import NDArray, sparse, xp


def compute_update_indices_sparse(
    M: sparse.csr_matrix, U: sparse.csr_matrix, destination_indexes: NDArray = None
) -> NDArray:
    """Computes the indices for updating the system matrix.

    For every stored entry ``U[i, j]``, taken in the storage order of
    ``U.data``, this computes the offset into ``M.data`` of the entry
    ``M[destination_indexes[i], destination_indexes[j]]``. The result can be
    used to scatter ``U.data`` into ``M.data``.

    The rows of ``M`` do not need sorted column indices.

    Parameters
    ----------
    M : sparse.csr_matrix
        The original system matrix. Every targeted entry must already be
        stored in its sparsity pattern.
    U : sparse.csr_matrix
        The update matrix to be applied.
    destination_indexes : NDArray, optional
        The row/column index in ``M`` for every row/column index of ``U``.
        It must not contain duplicates. Defaults to the identity mapping
        ``arange(M.shape[0])``.

    Returns
    -------
    target_indices : NDArray
        Offsets into ``M.data``, one per entry of ``U.data``, returned as an
        ``xp`` array.

    Raises
    ------
    ValueError
        If ``destination_indexes`` has duplicate entries, if a targeted row
        of ``M`` or a row of ``U`` has duplicate column indices, or if a
        targeted entry is not stored in ``M``.

    """

    # Work on host (CPU) copies; device arrays expose ``.get()``.
    M = M.get() if hasattr(M, "get") else M
    U = U.get() if hasattr(U, "get") else U

    if destination_indexes is None:
        # Default destination indexes to the identity mapping.
        destination_indexes = np.arange(M.shape[0], dtype=np.int64)
    else:
        # The mapping is combined with host-side CSR arrays below, so it must
        # live on the host as well.
        if hasattr(destination_indexes, "get"):
            destination_indexes = destination_indexes.get()
        destination_indexes = np.asarray(destination_indexes)

    if np.unique(destination_indexes).size != destination_indexes.size:
        raise ValueError(
            "The destination indexes have duplicate entries, cannot compute update indices."
        )

    update_indices = np.zeros_like(U.data, dtype=np.int64)

    for U_row in range(U.shape[0]):
        # Column indices stored in the current row of U.
        row_start, row_end = U.indptr[U_row], U.indptr[U_row + 1]
        U_cols = U.indices[row_start:row_end]

        # Target row in M and its stored column indices.
        M_row = destination_indexes[U_row]
        M_row_start, M_row_end = M.indptr[M_row], M.indptr[M_row + 1]
        M_cols = M.indices[M_row_start:M_row_end]

        # ``np.unique`` both detects duplicates and sorts the row. ``order``
        # maps each sorted column back to its position in the (possibly
        # unsorted) CSR row, so the binary search below is always valid.
        sorted_cols, order = np.unique(M_cols, return_index=True)
        if sorted_cols.size != M_cols.size:
            raise ValueError(
                "The system matrix has duplicate column indices in a row, cannot compute update indices."
            )

        # Duplicate columns in U would map several updates onto one entry.
        if np.unique(U_cols).size != U_cols.size:
            raise ValueError(
                "The update matrix has duplicate column indices in a row, cannot compute update indices."
            )

        # Map U column indices to column indices in M.
        U_cols_dest = destination_indexes[U_cols]

        # searchsorted returns sorted_cols.size for targets beyond the last
        # stored column (or for an empty row); reject those before indexing.
        found = np.searchsorted(sorted_cols, U_cols_dest)
        if not (found < sorted_cols.size).all() or (
            sorted_cols[found] != U_cols_dest
        ).any():
            raise ValueError(
                "Some destination indexes do not exist in the system matrix row, cannot compute update indices."
            )

        update_indices[row_start:row_end] = M_row_start + order[found]

    return xp.array(update_indices)


def compute_update_indices_dense(
    M: sparse.csr_matrix, destination_indexes: NDArray = None
) -> NDArray:
    """Computes the indices for updating the system matrix with a dense block.

    Let ``n = len(destination_indexes)``. Position ``i * n + j`` of the result
    holds the offset into ``M.data`` of the entry
    ``M[destination_indexes[i], destination_indexes[j]]``. The result
    therefore matches a dense ``n x n`` update flattened in row-major order.

    The rows of ``M`` do not need sorted column indices.

    Parameters
    ----------
    M : sparse.csr_matrix
        The original system matrix. Every targeted entry must already be
        stored in its sparsity pattern.
    destination_indexes : NDArray, optional
        The row/column index in ``M`` for every row/column of the dense
        update. It must not contain duplicates. Defaults to the identity
        mapping ``arange(M.shape[0])``.

    Returns
    -------
    target_indices : NDArray
        ``n**2`` offsets into ``M.data``, returned as an ``xp`` array.

    Raises
    ------
    ValueError
        If ``destination_indexes`` has duplicate entries, if a targeted row
        of ``M`` has duplicate column indices, or if a targeted entry is not
        stored in ``M``.

    """

    # Work on a host (CPU) copy; device arrays expose ``.get()``.
    M = M.get() if hasattr(M, "get") else M

    if destination_indexes is None:
        # Default destination indexes to the identity mapping.
        destination_indexes = np.arange(M.shape[0], dtype=np.int64)
    else:
        # The mapping is combined with host-side CSR arrays below, so it must
        # live on the host as well.
        if hasattr(destination_indexes, "get"):
            destination_indexes = destination_indexes.get()
        destination_indexes = np.asarray(destination_indexes)

    if np.unique(destination_indexes).size != destination_indexes.size:
        raise ValueError(
            "The destination indexes have duplicate entries, cannot compute update indices."
        )

    U_size = destination_indexes.shape[0]

    update_indices = np.zeros((U_size**2,), dtype=np.int64)

    for U_row, M_row in enumerate(destination_indexes):
        # Stored column indices of the targeted row of M.
        M_row_start, M_row_end = M.indptr[M_row], M.indptr[M_row + 1]
        M_cols = M.indices[M_row_start:M_row_end]

        # ``np.unique`` both detects duplicates and sorts the row. ``order``
        # maps each sorted column back to its position in the (possibly
        # unsorted) CSR row, so the binary search below is always valid.
        sorted_cols, order = np.unique(M_cols, return_index=True)
        if sorted_cols.size != M_cols.size:
            raise ValueError(
                "The system matrix has duplicate column indices in a row, cannot compute update indices."
            )

        # searchsorted returns sorted_cols.size for targets beyond the last
        # stored column (or for an empty row); reject those before indexing.
        found = np.searchsorted(sorted_cols, destination_indexes)
        if not (found < sorted_cols.size).all() or (
            sorted_cols[found] != destination_indexes
        ).any():
            raise ValueError(
                "Some destination indexes do not exist in the system matrix row, cannot compute update indices."
            )

        update_indices[U_row * U_size : (U_row + 1) * U_size] = (
            M_row_start + order[found]
        )

    return xp.array(update_indices)