Skip to content

qttools.fft.ffts#

[docs] module qttools.fft.ffts

# Copyright (c) 2024-2026 ETH Zurich and the authors of the qttools package.

from qttools import NDArray, xp

if xp.__name__ == "cupy":
    cache = xp.fft.config.get_plan_cache()


def fft_convolve(a: NDArray, b: NDArray) -> NDArray:
    """Computes the convolution of two arrays using FFT over the first axis (energy axis).

    Parameters
    ----------
    a : NDArray
        First array.
    b : NDArray
        Second array.

    Returns
    -------
    NDArray
        The convolution of the two arrays.
    """
    ne = (
        a.shape[0] + b.shape[0]
    )  # Should not have -1 here (otherwise hilbert transform fails)
    a_fft = xp.fft.fft(a, ne, axis=0)
    b_fft = xp.fft.fft(b, ne, axis=0)
    return xp.fft.ifft(a_fft * b_fft, axis=0)


def fft_circular_convolve(a: NDArray, b: NDArray, axes: tuple[int]) -> NDArray:
    """Computes the circular convolution of two arrays using the FFT.

    Parameters
    ----------
    a : NDArray
        First array.
    b : NDArray
        Second array.
    axes : tuple[int]
        The axes over which to perform the convolution.

    Returns
    -------
    NDArray
        The circular convolution of the two arrays.
    """
    # Extract the shapes of the arrays along the axes as tuples.
    nka = tuple(a.shape[i] for i in axes)
    nkb = tuple(b.shape[i] for i in axes)
    a_fft = xp.fft.fftn(a, nka, axes=axes)
    b_fft = xp.fft.fftn(b, nkb, axes=axes)
    return xp.fft.ifftn(a_fft * b_fft, axes=axes)


def fft_convolve_kpoints(a: NDArray, b: NDArray) -> NDArray:
    """Computes the convolution of two arrays using the FFT.

    The first axis is assumed to be the energy axis, the other
    axes are k-points and the last axis is the orbital index.

    Over the k-point axes, a circular convolution is performed.

    Parameters
    ----------
    a : NDArray
        First array.
    b : NDArray
        Second array.
    Returns
    -------
    NDArray
        The convolution of the two arrays.
    """
    ne = a.shape[0] + b.shape[0] - 1
    nka = a.shape[1:-1]
    nkb = b.shape[1:-1]
    a_fft = xp.fft.fftn(a, (ne,) + nka, axes=(0,) + tuple(range(1, len(nka) + 1)))
    b_fft = xp.fft.fftn(b, (ne,) + nkb, axes=(0,) + tuple(range(1, len(nkb) + 1)))
    return xp.fft.ifftn(a_fft * b_fft, axes=(0,) + tuple(range(1, len(nka) + 1)))


def fft_correlate_kpoints(a: NDArray, b: NDArray) -> NDArray:
    """Computes the correlation of two arrays using FFT.

    The first axis is assumed to be the energy axis, the other
    axes are k-points and the last axis is the orbital index.

    Over the k-point axes, a circular correlation is performed.

    Parameters
    ----------
    a : NDArray
        First array.
    b : NDArray
        Second array.

    Returns
    -------
    NDArray
        The cross-correlation of the two arrays.

    """
    ne = a.shape[0] + b.shape[0] - 1
    nka = a.shape[1:-1]
    nkb = b.shape[1:-1]
    a_fft = xp.fft.fftn(a, (ne,) + nka, axes=(0,) + tuple(range(1, len(nka) + 1)))
    b_fft = xp.fft.fftn(
        xp.flip(b, axis=(0,) + tuple(range(1, len(nkb) + 1))),
        (ne,) + nkb,
        axes=(0,) + tuple(range(1, len(nkb) + 1)),
    )
    # I have to roll the result by one, to get the gamma point at index 0.
    # The ordering is then: [0, 1, 2, ..., -2, -1] in terms of shifts.
    return xp.roll(
        xp.fft.ifftn(a_fft * b_fft, axes=(0,) + tuple(range(1, len(nka) + 1))),
        shift=1,
        axis=tuple(range(1, len(nka) + 1)),
    )