Skip to content

qttools.greens_function_solver.rgf_dist#

[docs] module qttools.greens_function_solver.rgf_dist

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

import numpy as np

from qttools import NDArray
from qttools.comm import comm
from qttools.datastructures.dsdbsparse import DSDBSparse
from qttools.greens_function_solver import _serinv
from qttools.greens_function_solver.solver import GFSolver, OBCBlocks
from qttools.profiling import Profiler
from qttools.utils.solvers_utils import get_batches

# Module-level profiler instance; `RGFDist.selected_solve` wraps its phases
# in labeled `profiler.profile_range(...)` contexts.
profiler = Profiler()


class RGFDist(GFSolver):
    """Distributed selected inversion solver.

    The stack dimension of the input (``a.shape[0]``) is processed in
    batches of at most ``max_batch_size`` entries. For every batch, each
    rank of ``comm.block`` first runs a local Schur-complement pass
    (``downward_schur`` on rank 0, ``upward_schur`` on the last rank,
    ``permuted_schur`` on all other ranks), then a reduced system is
    gathered, solved and scattered back, and finally the matching local
    selected-inversion pass fills the output.

    Parameters
    ----------
    max_batch_size : int, optional
        Maximum batch size to use when inverting the matrix, by default
        100.

    """

    def __init__(self, max_batch_size: int = 100) -> None:
        """Initializes the selected inversion solver."""
        self.max_batch_size = max_batch_size

    @staticmethod
    def _select_gather(reduced_system, a: DSDBSparse):
        """Picks the gather routine of `reduced_system` suited to `a`.

        Returns ``reduced_system.gather_constant_block_size`` if all
        entries of ``a.block_sizes`` are equal, otherwise the generic
        ``reduced_system.gather``.
        """
        if np.all(a.block_sizes == a.block_sizes[0]):
            return reduced_system.gather_constant_block_size
        # If the block sizes are not the same, we need to use pickle.
        return reduced_system.gather

    def selected_inv(
        self,
        a: DSDBSparse,
        out: DSDBSparse,
        obc_blocks: OBCBlocks | None = None,
    ) -> None | DSDBSparse:
        """Performs selected inversion of a block-tridiagonal matrix.

        The result is written in place into `out`, one stack batch at a
        time.

        Parameters
        ----------
        a : DSDBSparse
            Matrix to invert.
        out : DSDBSparse
            Preallocated output matrix. It is required: the method reads
            ``out.stack`` and writes the result into it.
        obc_blocks : OBCBlocks, optional
            OBC blocks forwarded to the Schur-complement routines. If
            None, ``OBCBlocks(num_blocks=a.num_local_blocks)`` is used.

        Returns
        -------
        None
            This implementation never returns a matrix; the selected
            inverse is stored in `out`.

        """
        num_local_blocks = a.num_local_blocks

        # Initialize temporary buffers.
        reduced_system = _serinv.ReducedSystem()

        # Initialize dense temporary buffers for the diagonal blocks and
        # the upper and lower auxiliary buffer blocks. They are reused
        # for every batch.
        x_diag_blocks: list[NDArray | None] = [None] * num_local_blocks
        buffer_lower: list[NDArray | None] = [None] * num_local_blocks
        buffer_upper: list[NDArray | None] = [None] * num_local_blocks

        if obc_blocks is None:
            obc_blocks = OBCBlocks(num_blocks=num_local_blocks)

        batch_sizes, batch_offsets = get_batches(a.shape[0], self.max_batch_size)

        # The gather routine depends only on the block sizes of `a`, so
        # it is selected once instead of once per batch.
        gather_reduced_system = self._select_gather(reduced_system, a)

        # Rank 0 is tested first, so with a single rank the "first"
        # branches are taken.
        is_first_rank = comm.block.rank == 0
        is_last_rank = comm.block.rank == comm.block.size - 1

        for batch in range(len(batch_sizes)):
            stack_slice = slice(
                int(batch_offsets[batch]), int(batch_offsets[batch + 1])
            )

            a_ = a.stack[stack_slice]
            out_ = out.stack[stack_slice]

            # Local Schur-complement pass.
            if is_first_rank:
                _serinv.downward_schur(
                    a_,
                    x_diag_blocks,
                    obc_blocks,
                    stack_slice=stack_slice,
                    invert_last_block=False,
                )
            elif is_last_rank:
                _serinv.upward_schur(
                    a_,
                    x_diag_blocks,
                    obc_blocks,
                    stack_slice=stack_slice,
                    invert_last_block=False,
                )
            else:
                _serinv.permuted_schur(
                    a_,
                    x_diag_blocks,
                    buffer_lower,
                    buffer_upper,
                    obc_blocks,
                    stack_slice=stack_slice,
                )

            # Construct the reduced system.
            gather_reduced_system(a_, x_diag_blocks, buffer_upper, buffer_lower)
            # Perform selected-inversion on the reduced system.
            reduced_system.solve()
            # Scatter the result to the output matrix.
            reduced_system.scatter(x_diag_blocks, buffer_upper, buffer_lower, out_)

            # Local selected-inversion pass; each rank uses the
            # counterpart of the Schur routine it ran above.
            if is_first_rank:
                _serinv.downward_selinv(a_, x_diag_blocks, out_)
            elif is_last_rank:
                _serinv.upward_selinv(a_, x_diag_blocks, out_)
            else:
                _serinv.permuted_selinv(
                    a_, x_diag_blocks, buffer_lower, buffer_upper, out_
                )

    def selected_solve(
        self,
        a: DSDBSparse,
        sigma_lesser: DSDBSparse,
        sigma_greater: DSDBSparse,
        out: tuple[DSDBSparse, ...],
        obc_blocks: OBCBlocks | None = None,
        return_retarded: bool = False,
    ):
        r"""Performs selected inversion of a block-tridiagonal matrix.

        Also solves the quadratic systems associated with the lesser
        and greater right-hand sides (the Bl and Bg matrices in the
        equation AXA^T = B). All results are written in place into the
        matrices in `out`, one stack batch at a time.

        Parameters
        ----------
        a : DSDBSparse
            Matrix to invert.
        sigma_lesser : DSDBSparse
            Lesser matrix. This matrix is expected to be
            skew-hermitian, i.e. \(\Sigma_{ij} = -\Sigma_{ji}^*\).
        sigma_greater : DSDBSparse
            Greater matrix. This matrix is expected to be
            skew-hermitian, i.e. \(\Sigma_{ij} = -\Sigma_{ji}^*\).
        out : tuple[DSDBSparse, ...]
            Preallocated output matrices, in the order
            ``(lesser, greater[, retarded])``. The retarded entry is
            only used if `return_retarded` is True.
        obc_blocks : OBCBlocks, optional
            OBC blocks for lesser, greater and retarded Green's
            functions. If None,
            ``OBCBlocks(num_blocks=a.num_local_blocks)`` is used.
        return_retarded : bool, optional
            Whether the retarded Green's function should be written to
            the third entry of `out` along with lesser and greater, by
            default False.

        Returns
        -------
        None
            The results are stored in the matrices passed via `out`.

        Raises
        ------
        ValueError
            If `return_retarded` is True and `out` does not contain
            exactly one matrix after the lesser and greater outputs.
            Unpacking `out` also fails with a ValueError if it holds
            fewer than two entries.

        """

        with profiler.profile_range(
            label="RGF dist: init", level="default", comm=comm
        ):
            num_local_blocks = a.num_local_blocks

            # Initialize temporary buffers. They are reused for every
            # batch.
            reduced_system = _serinv.ReducedSystem(selected_solve=True)

            xr_diag_blocks: list[NDArray | None] = [None] * num_local_blocks
            xr_buffer_lower: list[NDArray | None] = [None] * num_local_blocks
            xr_buffer_upper: list[NDArray | None] = [None] * num_local_blocks

            # No lower buffers are allocated for the lesser and greater
            # quantities; None is passed wherever one is accepted.
            xl_diag_blocks: list[NDArray | None] = [None] * num_local_blocks
            xl_buffer_lower = None
            xl_buffer_upper: list[NDArray | None] = [None] * num_local_blocks

            xg_diag_blocks: list[NDArray | None] = [None] * num_local_blocks
            xg_buffer_lower = None
            xg_buffer_upper: list[NDArray | None] = [None] * num_local_blocks

            if obc_blocks is None:
                obc_blocks = OBCBlocks(num_blocks=num_local_blocks)

            xl_out, xg_out, *xr_out = out
            if return_retarded:
                if len(xr_out) != 1:
                    raise ValueError("Invalid number of output matrices.")
                xr_out = xr_out[0]

            batch_sizes, batch_offsets = get_batches(
                a.shape[0], self.max_batch_size
            )

            # The gather routine depends only on the block sizes of
            # `a`, so it is selected once for all batches.
            gather_reduced_system = self._select_gather(reduced_system, a)

        # Rank 0 is tested first, so with a single rank the "first"
        # branches are taken.
        is_first_rank = comm.block.rank == 0
        is_last_rank = comm.block.rank == comm.block.size - 1

        # Every phase below must run once per batch: the temporary
        # buffers only ever hold the data of the current batch, and
        # each batch writes its own slice of the output matrices.
        for batch in range(len(batch_sizes)):
            stack_slice = slice(
                int(batch_offsets[batch]), int(batch_offsets[batch + 1])
            )

            a_ = a.stack[stack_slice]
            sigma_lesser_ = sigma_lesser.stack[stack_slice]
            sigma_greater_ = sigma_greater.stack[stack_slice]

            xl_out_ = xl_out.stack[stack_slice]
            xg_out_ = xg_out.stack[stack_slice]
            xr_out_ = xr_out.stack[stack_slice] if return_retarded else None

            with profiler.profile_range(
                label="RGF dist: Schur", level="default", comm=comm
            ):
                # Local Schur-complement pass.
                if is_first_rank:
                    _serinv.downward_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        invert_last_block=False,
                        selected_solve=True,
                    )
                elif is_last_rank:
                    _serinv.upward_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        invert_last_block=False,
                        selected_solve=True,
                    )
                else:
                    _serinv.permuted_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_buffer_lower=xr_buffer_lower,
                        xr_buffer_upper=xr_buffer_upper,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_buffer_lower=xl_buffer_lower,
                        xl_buffer_upper=xl_buffer_upper,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_buffer_lower=xg_buffer_lower,
                        xg_buffer_upper=xg_buffer_upper,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        selected_solve=True,
                    )

            with profiler.profile_range(
                label="RGF dist: Reduce gather", level="default", comm=comm
            ):
                # Construct the reduced system.
                gather_reduced_system(
                    a=a_,
                    xr_diag_blocks=xr_diag_blocks,
                    xr_buffer_lower=xr_buffer_lower,
                    xr_buffer_upper=xr_buffer_upper,
                    # Lesser quantities.
                    sigma_lesser=sigma_lesser_,
                    xl_diag_blocks=xl_diag_blocks,
                    xl_buffer_lower=xl_buffer_lower,
                    xl_buffer_upper=xl_buffer_upper,
                    # Greater quantities.
                    sigma_greater=sigma_greater_,
                    xg_diag_blocks=xg_diag_blocks,
                    xg_buffer_lower=xg_buffer_lower,
                    xg_buffer_upper=xg_buffer_upper,
                )

            # Perform selected-inversion on the reduced system.
            with profiler.profile_range(
                label="RGF dist: Reduce solve", level="default", comm=comm
            ):
                reduced_system.solve()

            with profiler.profile_range(
                label="RGF dist: Reduce scatter", level="default", comm=comm
            ):
                # Scatter the result to the output matrix.
                reduced_system.scatter(
                    xr_diag_blocks=xr_diag_blocks,
                    xr_buffer_lower=xr_buffer_lower,
                    xr_buffer_upper=xr_buffer_upper,
                    xr_out=xr_out_,
                    return_retarded=return_retarded,
                    # Lesser quantities.
                    xl_diag_blocks=xl_diag_blocks,
                    xl_buffer_lower=xl_buffer_lower,
                    xl_buffer_upper=xl_buffer_upper,
                    xl_out=xl_out_,
                    # Greater quantities.
                    xg_diag_blocks=xg_diag_blocks,
                    xg_buffer_lower=xg_buffer_lower,
                    xg_buffer_upper=xg_buffer_upper,
                    xg_out=xg_out_,
                )

            with profiler.profile_range(
                label="RGF dist: Selinv", level="default", comm=comm
            ):
                # Local selected-inversion pass; each rank uses the
                # counterpart of the Schur routine it ran above.
                if is_first_rank:
                    _serinv.downward_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )
                elif is_last_rank:
                    _serinv.upward_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )
                else:
                    # The lesser/greater lower buffers are not passed
                    # here; they are None throughout this method.
                    _serinv.permuted_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_buffer_lower=xr_buffer_lower,
                        xr_buffer_upper=xr_buffer_upper,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_buffer_upper=xl_buffer_upper,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_buffer_upper=xg_buffer_upper,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )