# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.
import numpy as np
from qttools import NDArray
from qttools.comm import comm
from qttools.datastructures.dsdbsparse import DSDBSparse
from qttools.greens_function_solver import _serinv
from qttools.greens_function_solver.solver import GFSolver, OBCBlocks
from qttools.profiling import Profiler
from qttools.utils.solvers_utils import get_batches
# Module-level profiler used to time the phases of the distributed solve
# (see the `profiler.profile_range` blocks in `RGFDist.selected_solve`).
profiler = Profiler()
class RGFDist(GFSolver):
    """Distributed selected inversion solver.

    Each rank of ``comm.block`` owns a contiguous range of block rows
    of a block-tridiagonal matrix. The solve proceeds in three phases:
    a local Schur-complement sweep on every rank, a cooperative solve
    of the small reduced system formed by the rank-boundary blocks
    (gather / solve / scatter via ``_serinv.ReducedSystem``), and a
    local selected-inversion back-substitution sweep.

    Parameters
    ----------
    max_batch_size : int, optional
        Maximum stack batch size to use when inverting the matrix, by
        default 100.

    """

    def __init__(self, max_batch_size: int = 100) -> None:
        """Initializes the selected inversion solver."""
        # The stack dimension is processed in chunks of at most this
        # size to bound the memory footprint of the dense temporaries.
        self.max_batch_size = max_batch_size

    def selected_inv(
        self,
        a: DSDBSparse,
        out: DSDBSparse,
        obc_blocks: OBCBlocks | None = None,
    ) -> None | DSDBSparse:
        """Performs selected inversion of a block-tridiagonal matrix.

        The selected elements of the inverse are written into ``out``
        in place.

        Parameters
        ----------
        a : DSDBSparse
            Matrix to invert, distributed block-row-wise across
            ``comm.block``.
        out : DSDBSparse
            Preallocated output matrix receiving the selected elements
            of the inverse.
        obc_blocks : OBCBlocks, optional
            Open-boundary-condition blocks, by default None. If None,
            empty OBC blocks are created.

        Returns
        -------
        None | DSDBSparse
            This implementation always returns None; the result is
            written into ``out``. (The ``DSDBSparse`` alternative in
            the annotation mirrors the :class:`GFSolver` interface.)

        """
        # Initialize temporary buffers.
        reduced_system = _serinv.ReducedSystem()
        # Initialize dense temporary buffers for the diagonal blocks and
        # the upper and lower auxiliary buffer blocks.
        x_diag_blocks: list[NDArray | None] = [None] * a.num_local_blocks
        buffer_lower: list[NDArray | None] = [None] * a.num_local_blocks
        buffer_upper: list[NDArray | None] = [None] * a.num_local_blocks
        if obc_blocks is None:
            obc_blocks = OBCBlocks(num_blocks=a.num_local_blocks)
        # Process the stack dimension in batches to limit peak memory.
        batch_sizes, batch_offsets = get_batches(a.shape[0], self.max_batch_size)
        for i in range(len(batch_sizes)):
            stack_slice = slice(int(batch_offsets[i]), int(batch_offsets[i + 1]))
            a_ = a.stack[stack_slice]
            out_ = out.stack[stack_slice]
            if comm.block.rank == 0:
                # First rank: downward Schur-complement sweep.
                _serinv.downward_schur(
                    a_,
                    x_diag_blocks,
                    obc_blocks,
                    stack_slice=stack_slice,
                    invert_last_block=False,
                )
            elif comm.block.rank == comm.block.size - 1:
                # Last rank: upward Schur-complement sweep.
                _serinv.upward_schur(
                    a_,
                    x_diag_blocks,
                    obc_blocks,
                    stack_slice=stack_slice,
                    invert_last_block=False,
                )
            else:
                # Middle ranks: permuted Schur-complement sweep,
                # eliminating towards both rank boundaries and filling
                # the lower/upper auxiliary buffers.
                _serinv.permuted_schur(
                    a_,
                    x_diag_blocks,
                    buffer_lower,
                    buffer_upper,
                    obc_blocks,
                    stack_slice=stack_slice,
                )
            # Construct the reduced system.
            if np.all(a.block_sizes == a.block_sizes[0]):
                # Constant block sizes allow a raw-buffer gather.
                gather_reduced_system = reduced_system.gather_constant_block_size
            else:
                # If the block sizes are not the same, we need to use pickle.
                gather_reduced_system = reduced_system.gather
            # NOTE(review): the buffers are passed positionally as
            # (upper, lower) here — confirm this order against the
            # ReducedSystem.gather* signatures.
            gather_reduced_system(a_, x_diag_blocks, buffer_upper, buffer_lower)
            # Perform selected-inversion on the reduced system.
            reduced_system.solve()
            # Scatter the result to the output matrix.
            reduced_system.scatter(x_diag_blocks, buffer_upper, buffer_lower, out_)
            if comm.block.rank == 0:
                # First rank: downward selected-inversion sweep.
                _serinv.downward_selinv(a_, x_diag_blocks, out_)
            elif comm.block.rank == comm.block.size - 1:
                # Last rank: upward selected-inversion sweep.
                _serinv.upward_selinv(a_, x_diag_blocks, out_)
            else:
                # Middle ranks: permuted selected-inversion sweep.
                _serinv.permuted_selinv(
                    a_, x_diag_blocks, buffer_lower, buffer_upper, out_
                )

    def selected_solve(
        self,
        a: DSDBSparse,
        sigma_lesser: DSDBSparse,
        sigma_greater: DSDBSparse,
        out: tuple[DSDBSparse, ...],
        obc_blocks: OBCBlocks | None = None,
        return_retarded: bool = False,
    ) -> None:
        r"""Performs selected inversion of a block-tridiagonal matrix.

        Additionally solves the quadratic systems associated with the
        lesser and greater right-hand sides in the equation
        AXA^T = B. All results are written into ``out`` in place.

        Parameters
        ----------
        a : DSDBSparse
            Matrix to invert.
        sigma_lesser : DSDBSparse
            Lesser matrix. This matrix is expected to be
            skew-hermitian, i.e. \(\Sigma_{ij} = -\Sigma_{ji}^*\).
        sigma_greater : DSDBSparse
            Greater matrix. This matrix is expected to be
            skew-hermitian, i.e. \(\Sigma_{ij} = -\Sigma_{ji}^*\).
        out : tuple[DSDBSparse, ...]
            Preallocated output matrices ``(xl, xg)``, or
            ``(xl, xg, xr)`` when ``return_retarded`` is True.
        obc_blocks : OBCBlocks, optional
            OBC blocks for lesser, greater and retarded Green's
            functions. By default None.
        return_retarded : bool, optional
            Whether the retarded Green's function should be returned
            along with lesser and greater, by default False.

        Raises
        ------
        ValueError
            If ``return_retarded`` is True and ``out`` does not
            contain exactly one retarded output matrix after the
            lesser and greater ones.

        """
        with profiler.profile_range(label="RGF dist: init", level="default", comm=comm):
            # Initialize temporary buffers.
            reduced_system = _serinv.ReducedSystem(selected_solve=True)
            # Dense temporaries for the retarded quantities.
            xr_diag_blocks: list[NDArray | None] = [None] * a.num_local_blocks
            xr_buffer_lower: list[NDArray | None] = [None] * a.num_local_blocks
            xr_buffer_upper: list[NDArray | None] = [None] * a.num_local_blocks
            # Dense temporaries for the lesser quantities. The lower
            # buffer is not allocated and None is forwarded instead —
            # presumably it is recoverable from the upper buffer via
            # the skew-hermitian symmetry of the lesser matrix (TODO:
            # confirm against the _serinv kernels).
            xl_diag_blocks: list[NDArray | None] = [None] * a.num_local_blocks
            xl_buffer_lower = None
            xl_buffer_upper: list[NDArray | None] = [None] * a.num_local_blocks
            # Dense temporaries for the greater quantities; same
            # remark as above for the missing lower buffer.
            xg_diag_blocks: list[NDArray | None] = [None] * a.num_local_blocks
            xg_buffer_lower = None
            xg_buffer_upper: list[NDArray | None] = [None] * a.num_local_blocks
            if obc_blocks is None:
                obc_blocks = OBCBlocks(num_blocks=a.num_local_blocks)
            # The retarded output is only present when requested.
            xl_out, xg_out, *xr_out = out
            if return_retarded:
                if len(xr_out) != 1:
                    raise ValueError("Invalid number of output matrices.")
                xr_out = xr_out[0]
        # Process the stack dimension in batches to limit peak memory.
        batch_sizes, batch_offsets = get_batches(a.shape[0], self.max_batch_size)
        for i in range(len(batch_sizes)):
            stack_slice = slice(int(batch_offsets[i]), int(batch_offsets[i + 1]))
            a_ = a.stack[stack_slice]
            sigma_lesser_ = sigma_lesser.stack[stack_slice]
            sigma_greater_ = sigma_greater.stack[stack_slice]
            xl_out_ = xl_out.stack[stack_slice]
            xg_out_ = xg_out.stack[stack_slice]
            xr_out_ = xr_out.stack[stack_slice] if return_retarded else None
            with profiler.profile_range(
                label="RGF dist: Schur", level="default", comm=comm
            ):
                if comm.block.rank == 0:
                    # First rank: downward Schur-complement sweep.
                    _serinv.downward_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        invert_last_block=False,
                        selected_solve=True,
                    )
                elif comm.block.rank == comm.block.size - 1:
                    # Last rank: upward Schur-complement sweep.
                    _serinv.upward_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        invert_last_block=False,
                        selected_solve=True,
                    )
                else:
                    # Middle ranks: permuted Schur-complement sweep,
                    # eliminating towards both rank boundaries.
                    _serinv.permuted_schur(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_buffer_lower=xr_buffer_lower,
                        xr_buffer_upper=xr_buffer_upper,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_buffer_lower=xl_buffer_lower,
                        xl_buffer_upper=xl_buffer_upper,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_buffer_lower=xg_buffer_lower,
                        xg_buffer_upper=xg_buffer_upper,
                        # OBC and settings.
                        obc_blocks=obc_blocks,
                        stack_slice=stack_slice,
                        selected_solve=True,
                    )
            with profiler.profile_range(
                label="RGF dist: Reduce gather", level="default", comm=comm
            ):
                # Construct the reduced system.
                if np.all(a.block_sizes == a.block_sizes[0]):
                    # Constant block sizes allow a raw-buffer gather.
                    gather_reduced_system = reduced_system.gather_constant_block_size
                else:
                    # If the block sizes are not the same, we need to use pickle.
                    gather_reduced_system = reduced_system.gather
                gather_reduced_system(
                    a=a_,
                    xr_diag_blocks=xr_diag_blocks,
                    xr_buffer_lower=xr_buffer_lower,
                    xr_buffer_upper=xr_buffer_upper,
                    # Lesser quantities.
                    sigma_lesser=sigma_lesser_,
                    xl_diag_blocks=xl_diag_blocks,
                    xl_buffer_lower=xl_buffer_lower,
                    xl_buffer_upper=xl_buffer_upper,
                    # Greater quantities.
                    sigma_greater=sigma_greater_,
                    xg_diag_blocks=xg_diag_blocks,
                    xg_buffer_lower=xg_buffer_lower,
                    xg_buffer_upper=xg_buffer_upper,
                )
            # Perform selected-inversion on the reduced system.
            with profiler.profile_range(
                label="RGF dist: Reduce solve", level="default", comm=comm
            ):
                reduced_system.solve()
            with profiler.profile_range(
                label="RGF dist: Reduce scatter", level="default", comm=comm
            ):
                # Scatter the result to the output matrix.
                reduced_system.scatter(
                    xr_diag_blocks=xr_diag_blocks,
                    xr_buffer_lower=xr_buffer_lower,
                    xr_buffer_upper=xr_buffer_upper,
                    xr_out=xr_out_,
                    return_retarded=return_retarded,
                    # Lesser quantities.
                    xl_diag_blocks=xl_diag_blocks,
                    xl_buffer_lower=xl_buffer_lower,
                    xl_buffer_upper=xl_buffer_upper,
                    xl_out=xl_out_,
                    # Greater quantities.
                    xg_diag_blocks=xg_diag_blocks,
                    xg_buffer_lower=xg_buffer_lower,
                    xg_buffer_upper=xg_buffer_upper,
                    xg_out=xg_out_,
                )
            with profiler.profile_range(
                label="RGF dist: Selinv", level="default", comm=comm
            ):
                if comm.block.rank == 0:
                    # First rank: downward selected-inversion sweep.
                    _serinv.downward_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )
                elif comm.block.rank == comm.block.size - 1:
                    # Last rank: upward selected-inversion sweep.
                    _serinv.upward_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )
                else:
                    # Middle ranks: permuted selected-inversion sweep.
                    _serinv.permuted_selinv(
                        a=a_,
                        xr_diag_blocks=xr_diag_blocks,
                        xr_buffer_lower=xr_buffer_lower,
                        xr_buffer_upper=xr_buffer_upper,
                        xr_out=xr_out_,
                        # Lesser quantities.
                        sigma_lesser=sigma_lesser_,
                        xl_diag_blocks=xl_diag_blocks,
                        # xl_buffer_lower is None and deliberately not forwarded.
                        xl_buffer_upper=xl_buffer_upper,
                        xl_out=xl_out_,
                        # Greater quantities.
                        sigma_greater=sigma_greater_,
                        xg_diag_blocks=xg_diag_blocks,
                        # xg_buffer_lower is None and deliberately not forwarded.
                        xg_buffer_upper=xg_buffer_upper,
                        xg_out=xg_out_,
                        selected_solve=True,
                        return_retarded=return_retarded,
                    )