Skip to content

qttools.lyapunov.spectral

[docs] module qttools.lyapunov.spectral

# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import warnings

from qttools import NDArray, xp
from qttools.kernels import linalg
from qttools.lyapunov.lyapunov import LyapunovSolver
from qttools.lyapunov.utils import system_reduction
from qttools.profiling import Profiler

# Module-level profiler instance; its ``profile`` decorator is applied to the
# solver methods in this module via ``@profiler.profile(level=...)``.
profiler = Profiler()


class Spectral(LyapunovSolver):
    """A solver for the Lyapunov equation by using the matrix spectrum.

    Solves the discrete-time Lyapunov equation ``x = q + a @ x @ a^H``.
    The system matrix is diagonalized as ``a = V diag(w) V^{-1}``. In the
    eigenbasis the equation decouples element-wise into
    ``x_tilde[i, j] = gamma[i, j] / (1 - w[i] * conj(w[j]))`` with
    ``gamma = V^{-1} q V^{-H}``. The solution is transformed back via
    ``x = V x_tilde V^H`` and then refined with a few fixed-point
    iterations of the Lyapunov recursion.

    Parameters
    ----------
    num_ref_iterations : int, optional
        The number of refinement iterations to perform. At least one
        iteration is always performed, because the last one is used to
        estimate the recursion error. Default is 3.
    warning_threshold : float, optional
        The threshold for the relative recursion error to issue a warning.
        Default is `1e-1`.
    eig_compute_location : str, optional
        The location where to compute the eigenvalues and eigenvectors.
        Can be either "numpy" or "cupy". Only relevant if cupy is used.
        Default is "numpy".
    reduce_sparsity : bool, optional
        Whether to reduce the sparsity of the system matrix (via
        ``system_reduction``) before solving. Default is `True`.
    use_pinned_memory : bool, optional
        Whether to use pinned memory if cupy is used.
        Default is `True`.

    """

    def __init__(
        self,
        num_ref_iterations: int = 3,
        warning_threshold: float = 1e-1,
        eig_compute_location: str = "numpy",
        reduce_sparsity: bool = True,
        use_pinned_memory: bool = True,
    ) -> None:
        """Initializes the spectral Lyapunov solver."""
        self.num_ref_iterations = num_ref_iterations
        self.warning_threshold = warning_threshold
        self.eig_compute_location = eig_compute_location
        self.reduce_sparsity = reduce_sparsity
        self.use_pinned_memory = use_pinned_memory

    @staticmethod
    def _relative_recursion_error(x: NDArray, x_ref: NDArray) -> float:
        """Returns the largest relative change between two iterates.

        The relative change of each batch entry is
        ``||x_ref - x||_F / ||x_ref||_F``. Entries where ``x_ref`` has
        zero norm contribute ``0`` if ``x`` is zero as well and ``inf``
        otherwise. This avoids ``0 / 0 = nan``, which would propagate
        through the maximum and mask the errors of all other batch
        entries. NaNs that originate from the iterates themselves are
        still propagated.

        Parameters
        ----------
        x : NDArray
            The previous iterate.
        x_ref : NDArray
            The last (refined) iterate.

        Returns
        -------
        recursion_error : float
            The maximum relative change over the batch (``0.0`` for an
            empty batch).

        """
        diff_norm = xp.linalg.norm(x_ref - x, axis=(-2, -1))
        ref_norm = xp.linalg.norm(x_ref, axis=(-2, -1))

        zero_ref = ref_norm == 0
        relative_error = xp.where(
            zero_ref,
            xp.where(diff_norm == 0, 0.0, xp.inf),
            # Substitute 1 for zero denominators so no 0/0 is evaluated;
            # those entries are taken from the branch above anyway.
            diff_norm / xp.where(zero_ref, 1.0, ref_norm),
        )

        if relative_error.size == 0:
            return 0.0

        return float(xp.max(relative_error))

    @profiler.profile(level="debug")
    def _solve(
        self,
        a: NDArray,
        q: NDArray,
        contact: str,
        out: None | NDArray = None,
    ) -> NDArray | None:
        """Computes the solution of the discrete-time Lyapunov equation.

        Parameters
        ----------
        a : NDArray
            The system matrix.
        q : NDArray
            The right-hand side matrix. Its shape is the broadcast
            target for `a`.
        contact : str
            The contact to which the boundary blocks belong. Not used by
            this solver; kept so the signature matches the callback
            passed to ``system_reduction``.
        out : NDArray, optional
            The array to store the result in. If not provided, a new
            array is returned.

        Returns
        -------
        x : NDArray | None
            The solution of the discrete-time Lyapunov equation, or
            `None` if `out` was given.

        """

        ws, vs = linalg.eig(
            a,
            compute_module=self.eig_compute_location,
            use_pinned_memory=self.use_pinned_memory,
        )

        # Transform the right-hand side into the eigenbasis:
        # gamma = V^{-1} q V^{-H}.
        inv_vs = xp.linalg.inv(vs)
        inv_vs = xp.broadcast_to(inv_vs, q.shape)
        gamma = inv_vs @ q @ inv_vs.conj().swapaxes(-1, -2)

        # In the eigenbasis the equation is diagonal:
        # x_tilde[i, j] * (1 - w[i] * conj(w[j])) = gamma[i, j].
        phi = xp.ones_like(a) - xp.einsum("...i, ...j -> ...ij", ws, ws.conj())
        phi = xp.broadcast_to(phi, q.shape)
        x_tilde = 1 / phi * gamma

        # Back-transform: x = V x_tilde V^H.
        x = vs @ x_tilde @ vs.conj().swapaxes(-1, -2)

        # The conjugate transpose of `a` is loop-invariant; compute it once
        # on the unbroadcast matrix and broadcast the result.
        a_h = xp.broadcast_to(a.conj().swapaxes(-2, -1), q.shape)
        a = xp.broadcast_to(a, q.shape)

        # Perform a number of refinement iterations. The last one is done
        # separately so the previous iterate is kept for the error check.
        for __ in range(self.num_ref_iterations - 1):
            x = q + a @ x @ a_h

        x_ref = q + a @ x @ a_h

        # Check the maximum relative recursion error over the batch.
        recursion_error = self._relative_recursion_error(x, x_ref)
        # ``not <=`` (rather than ``>``) also warns when the error is NaN,
        # e.g. because the eigendecomposition produced non-finite values.
        if not recursion_error <= self.warning_threshold:
            warnings.warn(
                f"High relative recursion error: {recursion_error:.2e}",
                RuntimeWarning,
            )

        if out is not None:
            out[...] = x_ref
            return

        return x_ref

    @profiler.profile(level="api")
    def __call__(
        self,
        a: NDArray,
        q: NDArray,
        contact: str,
        out: None | NDArray = None,
    ) -> NDArray | None:
        """Computes the solution of the discrete-time Lyapunov equation.

        The matrices `a` and `q` may have different numbers of
        dimensions as long as ``q.ndim >= a.ndim``; `a` is broadcast to
        the shape of `q`.

        Parameters
        ----------
        a : NDArray
            The system matrix.
        q : NDArray
            The right-hand side matrix.
        contact : str
            The contact to which the boundary blocks belong.
        out : NDArray, optional
            The array to store the result in. If not provided, a new
            array is returned.

        Returns
        -------
        x : NDArray | None
            The solution of the discrete-time Lyapunov equation, or
            `None` if `out` was given.

        """

        assert q.shape[-2:] == a.shape[-2:], (
            f"Trailing matrix shapes must match, got q: {q.shape[-2:]} "
            f"and a: {a.shape[-2:]}."
        )
        assert q.ndim >= a.ndim, (
            f"q must have at least as many dimensions as a, got "
            f"q.ndim={q.ndim} and a.ndim={a.ndim}."
        )

        # NOTE: possible to cache the sparsity reduction
        if self.reduce_sparsity:
            return system_reduction(a, q, contact, self._solve, out=out)

        return self._solve(a, q, contact, out=out)