"""cuDSS-based wave function solver (qttools.wave_function_solver.cudss)."""

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

import ctypes.util
from collections.abc import Sequence

from qttools import NDArray, sparse
from qttools.wave_function_solver.solver import WFSolver

try:
    from nvmath.bindings import cudss
    from nvmath.internal import tensor_wrapper, utils
    from nvmath.sparse._internal import cudss_utils
    from nvmath.sparse.advanced import DirectSolver
    from nvmath.sparse.advanced.direct_solver import (
        axis_order_in_memory,
        calculate_strides,
    )

    nvmath_available = True

except ImportError:
    nvmath_available = False


# Possible multithreading libraries for cuDSS in descending order of
# preference. The first candidate that can be located on the system is
# handed to the solver via the "multithreading_lib" option.
# NOTE(review): `ctypes.util.find_library` on Linux usually expects a
# bare name (e.g. "m" for libm) rather than a full soname — confirm
# that these entries are actually resolvable on the target systems.
_mtlayer_libs = [
    "libcudss_mtlayer_gomp.so.0",
    "libcudss_mtlayer_gomp.so.0.5.0",
]


class cuDSS(WFSolver):
    """Wave function solver using cuDSS for sparse matrix solving.

    This solver uses the cuDSS library to solve sparse linear systems
    on NVIDIA GPUs.

    Parameters
    ----------
    explicitely_reset_operands : str, optional
        String indicating which operands to reset explicitly. If "a" is
        in the string, the system matrix `a` will be reset before
        solving. If "b" is in the string, the right-hand side vector `b`
        will be reset before solving. Default is "b", meaning only the
        right-hand side vector will be reset.
    use_multithreading : bool, optional
        Whether to use multithreading for the solver. If True, it will
        attempt to find a suitable multithreading library. Default is
        True.

    """

    def __init__(
        self,
        explicitely_reset_operands: str = "b",
        use_multithreading: bool = True,
    ) -> None:
        """Initializes the cuDSS wave function solver."""
        if not nvmath_available:
            raise ImportError(
                "cuDSS is not available. Please install it to use this solver."
            )

        # The solver context and plan are created lazily on the first
        # call to `solve`.
        self.direct_solver = None
        self.plan_info = None
        self.explicitely_reset_a = "a" in explicitely_reset_operands
        self.explicitely_reset_b = "b" in explicitely_reset_operands

        self.solver_options = {}

        if use_multithreading:
            # Pick the first library in `_mtlayer_libs` that is actually
            # present on the system ("descending order of preference").
            # BUGFIX: the previous `.pop(0)` always took the first
            # lookup result even when it was None, so a later candidate
            # that *was* installed would never be used.
            multithreading_lib = next(
                (
                    lib
                    for lib in map(ctypes.util.find_library, _mtlayer_libs)
                    if lib is not None
                ),
                None,
            )

            if multithreading_lib is None:
                print("WARNING! No suitable multithreading library found for cuDSS.")
            else:
                self.solver_options["multithreading_lib"] = multithreading_lib

    def __del__(self) -> None:
        """Cleans up the cuDSS solver context."""
        # Use getattr: `__del__` may run on a partially initialized
        # instance (e.g. when `__init__` raised ImportError before the
        # attribute was assigned), and attribute errors raised during
        # garbage collection are noisy and hard to diagnose.
        direct_solver = getattr(self, "direct_solver", None)
        if direct_solver is not None:
            direct_solver.free()
            self.direct_solver = None

    def _explicitely_reset_b(self, b: NDArray) -> None:
        """Resets the right-hand side in the direct solver.

        This is a workaround for the current limitation in nvmath where
        resetting the right-hand side does not support different shapes
        or strides than the original one.

        Parameters
        ----------
        b : NDArray
            The new right-hand side vector to set.

        Raises
        ------
        NotImplementedError
            If `b` is an explicitly batched sequence of right-hand
            sides, or if the solver copies operands across memory
            spaces.
        TypeError
            If the package, memory space, device id, or dtype of `b`
            does not match the one the solver was created with.

        """
        ds = self.direct_solver

        stream_holder = utils.get_or_create_stream(
            device_id=ds.device_id, stream=None, op_package=ds.rhs_package
        )

        # Do the checks that nvmath does too.
        # NOTE: Check the direct_solver.py source code for the original
        # implementation.
        explicitly_batched = isinstance(b, Sequence)
        if explicitly_batched:
            raise NotImplementedError(
                "Explicitly batched right-hand sides are not supported."
            )
        b = tensor_wrapper.wrap_operand(b)
        rhs_package = utils.infer_object_package(b.tensor)

        device_id = b.device_id
        memory_space = b.device
        value_type = b.dtype
        shape = b.shape
        strides = b.strides

        # Handle cupy <> numpy asymmetry. See note #2.
        if rhs_package == "numpy":
            rhs_package = "cupy"

        # Check package, device ID, shape, strides, and dtype.
        if rhs_package != ds.rhs_package:
            raise TypeError(
                f"The package for 'b' ({rhs_package}) doesn't match the original one ({ds.rhs_package})."
            )
        if memory_space != ds.memory_space:
            raise TypeError(
                f"The memory space for 'b' ({memory_space}) doesn't match the original one ({ds.memory_space})."
            )
        if device_id != "cpu" and device_id != ds.device_id:
            raise TypeError(
                f"The device id for 'b' ({device_id}) doesn't match the original one ({ds.device_id})."
            )
        if value_type != ds.value_type:
            raise TypeError(
                f"The dtype for 'b' ({value_type}) doesn't match the original one ({ds.value_type})."
            )

        if ds.copy_across_memspace:
            raise NotImplementedError(
                "Copying across memory spaces is not supported for resetting the right-hand side."
            )
        ds.b = b

        # Update the rhs and result shape and strides.
        ds.rhs_shape = shape
        ds.rhs_strides = strides

        ds.result_shape = shape
        # For single or implicitly-batched RHS, the matrix may not be
        # compact so we use the axis ordering to determine the strides.
        # NOTE(review): mirrored from nvmath's direct_solver.py; the
        # exact semantics of these two helpers live there.
        axis_order = axis_order_in_memory(ds.rhs_shape, ds.rhs_strides)
        ds.result_strides = calculate_strides(ds.rhs_shape, axis_order)

        # NOTE: Now it is not enough to just update the pointer
        # references, and keep reference to the internal buffers.
        # Instead, we need to destroy the existing resources and create
        # new ones with the updated right-hand side.

        # Free matrix pointers.
        cudss.matrix_destroy(ds.x_ptr)
        cudss.matrix_destroy(ds.b_ptr)

        ds.resources_b, ds.b_ptr = cudss_utils.create_cudss_dense_wrapper(
            ds.cuda_index_type,
            ds.cuda_value_type,
            ds.index_type,
            ds.batch_indices,
            ds.b,
            stream_holder,
        )
        # Use `b` for creating the (potentially explicitly or implicitly
        # batched) solution matrix or vector. The pointers will be
        # updated later in execute.
        ds.resources_x, ds.x_ptr = cudss_utils.create_cudss_dense_wrapper(
            ds.cuda_index_type,
            ds.cuda_value_type,
            ds.index_type,
            ds.batch_indices,
            ds.b,
            stream_holder,
        )

    def solve(self, a: sparse.spmatrix, b: NDArray) -> NDArray:
        """Solves the sparse linear system a @ x = b using cuDSS.

        The solver context is created lazily on the first call and
        reused afterwards; depending on the configuration from
        `__init__`, the operands `a` and/or `b` are reset on subsequent
        calls instead of rebuilding the whole context.

        Parameters
        ----------
        a : sparse.spmatrix
            The sparse system matrix.
        b : NDArray
            The right-hand side vector.

        Returns
        -------
        x : NDArray
            The solution vector.

        """
        if self.direct_solver is None or self.plan_info is None:
            self.direct_solver = DirectSolver(a, b, options=self.solver_options)
            # NOTE: By default this uses a custom nested dissectioning
            # scheme based on METIS. Other options could in principle be
            # exposed as a parameter.
            self.plan_info = self.direct_solver.plan()

        if self.explicitely_reset_a:
            self.direct_solver.reset_operands(a=a)
            # After resetting a, we need to re-plan.
            self.plan_info = self.direct_solver.plan()

        if self.explicitely_reset_b:
            # TODO: This does not support setting a right-hand side that
            # has a different shape or strides than the original one.
            # Until it is fixed in nvmath, we will hack around it by
            # resetting the operands.
            # self.direct_solver.reset_operands(b=b)
            self._explicitely_reset_b(b)

        self.direct_solver.factorize()
        return self.direct_solver.solve()