Skip to content

qttools.utils.input_utils#

[docs] module qttools.utils.input_utils

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
# Copyright (c) 2024 ETH Zurich and the authors of the qttools package.

import re
from pathlib import Path

import numpy as np

from qttools import NDArray, _DType, sparse, xp


def read_hr_dat(
    path: Path, return_all: bool = False, dtype: _DType = xp.complex128, read_fast=False
) -> tuple[NDArray, ...]:
    """Parses the contents of a `seedname_hr.dat` file.

    The first line gives the date and time at which the file was
    created. The second line states the number of Wannier functions
    `num_wann`. The third line gives the number of Wigner-Seitz
    grid-points.

    The next block of integers gives the degeneracy of each Wigner-Seitz
    grid point, arranged into 15 values per line.

    Finally, the remaining lines each contain, respectively, the
    components of the Wigner-Seitz cell index, the Wannier center
    indices m and n, and and the real and imaginary parts of the
    Hamiltonian matrix element `HRmn` in the localized basis.

    Parameters
    ----------
    path : Path
        Path to a `seedname_hr.dat` file.
    return_all : bool, optional
        Whether to return all the data or just the Hamiltonian in the
        localized basis. When `True`, the degeneracies and the
        Wigner-Seitz cell indices are also returned. Defaults to
        `False`.
    dtype : dtype, optional
        The data type of the Hamiltonian matrix elements. Defaults to
        `numpy.complex128`.
    read_fast : bool, optional
        Whether to assume that the file is well-formatted and all the
        data is sorted correctly. Defaults to `False`.

    Returns
    -------
    hr : ndarray
        The Hamiltonian matrix elements in the localized basis.
    degeneracies : ndarray, optional
        The degeneracies of the Wigner-Seitz grid points.
    R : ndarray, optional
        The Wigner-Seitz cell indices.

    """

    # Strip info from header.
    num_wann, nrpts = xp.loadtxt(path, skiprows=1, max_rows=2, dtype=int)
    num_wann, nrpts = int(num_wann), int(nrpts)

    # Read wannier data (skipping degeneracy info).
    deg_rows = int(xp.ceil(nrpts / 15.0))
    wann_dat = xp.loadtxt(path, skiprows=3 + deg_rows)

    # Assign R
    if read_fast:
        R = wann_dat[:: num_wann**2, :3].astype(int)
    else:
        R = wann_dat[:, :3].astype(int)
    Rs = xp.subtract(R, R.min(axis=0))
    N1, N2, N3 = Rs.max(axis=0) + 1
    N1, N2, N3 = int(N1), int(N2), int(N3)

    # Obtain Hamiltonian elements.
    if read_fast:
        hR = wann_dat[:, 5] + 1j * wann_dat[:, 6]
        hR = hR.reshape(N1, N2, N3, num_wann, num_wann).swapaxes(-2, -1)
        hR = xp.roll(hR, shift=(N1 // 2 + 1, N2 // 2 + 1, N3 // 2 + 1), axis=(0, 1, 2))
    else:
        hR = xp.zeros((N1, N2, N3, num_wann, num_wann), dtype=dtype)
        for line in wann_dat:
            R1, R2, R3 = line[:3].astype(int)
            m, n = line[3:5].astype(int)
            hR_mn_real, hR_mn_imag = line[5:]
            hR[R1, R2, R3, m - 1, n - 1] = hR_mn_real + 1j * hR_mn_imag

    if return_all:
        return hR, xp.unique(R, axis=0)
    return hR


def read_wannier_wout(
    path: Path, transform_home_cell: bool = True
) -> tuple[NDArray, NDArray]:
    """Parses the contents of a `seedname.wout` file and returns the Wannier centers and lattice vectors.

    TODO: Add tests.

    Parameters
    ----------
    path : Path
        Path to a `seedname.wout` file.
    transform_home_cell : bool, optional
        Whether to transform the Wannier centers to the home cell. Defaults to `True`.

    Returns
    -------
    wannier_centers : ndarray
        The Wannier centers, one row per Wannier function.
    lattice_vectors : ndarray
        The lattice vectors, one row per vector.

    Raises
    ------
    ValueError
        If the lattice vectors, the number of Wannier functions or the
        "Final State" section cannot be found in the file.
    """
    with open(path, "r") as f:
        lines = f.readlines()

    # Scan the header: the three lines following "Lattice Vectors" hold
    # one labelled vector each. The scan stops at the line stating the
    # number of Wannier functions.
    lattice_vectors = None
    num_wann = None
    for i, line in enumerate(lines):
        if "Lattice Vectors" in line:
            lattice_vectors = xp.asarray(
                [
                    [float(value) for value in lines[i + j + 1].split()[1:]]
                    for j in range(3)
                ]
            )
        if "Number of Wannier Functions" in line:
            # The count is the second-to-last token of the line.
            num_wann = int(line.split()[-2])
            break

    if lattice_vectors is None:
        raise ValueError(f"No 'Lattice Vectors' section found in {path}.")
    if num_wann is None:
        raise ValueError(f"No 'Number of Wannier Functions' entry found in {path}.")

    # Find the last "Final State" section by searching from the end of
    # the file. The Wannier centers are enclosed by parentheses, so the
    # comma-separated coordinates have to be extracted from them.
    center_pattern = re.compile(r"\((.*?)\)")
    wannier_centers = None
    for start in range(len(lines) - 1, -1, -1):
        if "Final State" in lines[start]:
            wannier_centers = xp.asarray(
                [
                    [
                        float(coordinate)
                        for coordinate in center_pattern.findall(
                            lines[start + 1 + j]
                        )[0].split(",")
                    ]
                    for j in range(num_wann)
                ]
            )
            break

    if wannier_centers is None:
        raise ValueError(f"No 'Final State' section found in {path}.")

    if transform_home_cell:
        # Express the centers in fractional (lattice) coordinates.
        fractional = xp.dot(wannier_centers, xp.linalg.inv(lattice_vectors))
        # Wrap the fractional coordinates into the home cell [0, 1).
        fractional = xp.mod(fractional, 1)
        # Transform the Wannier centers back to the original basis.
        wannier_centers = xp.dot(fractional, lattice_vectors)

    return wannier_centers, lattice_vectors


def cutoff_hr(
    hr: NDArray,
    value_cutoff: float | None = None,
    R_cutoff: int | tuple[int, int, int] | None = None,
    remove_zeros: bool = False,
) -> NDArray:
    """Cutoffs the Hamiltonian matrix elements based on their values and/or the wigner-seitz cell indices.

    TODO: Add tests.

    Parameters
    ----------
    hr : ndarray
        Wannier Hamiltonian.
    value_cutoff : float, optional
        Cutoff value for the Hamiltonian. Defaults to `None`.
    R_cutoff : int or tuple, optional
        Cutoff distance for the Hamiltonian. Defaults to `None`.
    remove_zeros : bool, optional
        Whether to remove cell planes with only zeros. Defaults to `False`.

    Returns
    -------
    ndarray
        The cutoff Hamiltonian.
    """
    hr_cut = None
    if value_cutoff is None and R_cutoff is None:
        return hr.copy()
    if R_cutoff is not None:
        if isinstance(R_cutoff, int):
            R_cutoff = (R_cutoff, R_cutoff, R_cutoff)
        hr_cut = xp.zeros_like(hr)
        for ind in xp.ndindex(hr.shape[:3]):
            ind = xp.asarray(ind) - xp.asarray(hr.shape[:3]) // 2
            if (abs(ind) <= xp.asarray(R_cutoff)).all():
                hr_cut[*ind] = hr[*ind]
    if value_cutoff is not None:
        if hr_cut is None:
            hr_cut = hr.copy()
        hr_cut[xp.abs(hr_cut) > value_cutoff] = 0

    # Remove eventual cell planes with only zeros, except the center.
    if remove_zeros:
        zero_mask = hr_cut.any(axis=(-2, -1))
        for ax in range(3):  # Loop through axes (0, 1, 2)
            # Loop backwards through the axis (from the edge to the center).
            for idx in range(hr_cut.shape[ax] // 2, 0, -1):
                axes_to_remove = []
                # Check if all elements are False in the cell plane.
                if not zero_mask.take(idx, axis=ax).any():
                    # If so, remove it.
                    axes_to_remove.append(idx)
                elif not zero_mask.take(-idx, axis=ax).any():
                    axes_to_remove.append(-idx)
                else:
                    # If not, break the loop (to not mess with ordering incase zero planes are not at the edge).
                    break
                hr_cut = xp.delete(hr_cut, axes_to_remove, axis=ax)

    return hr_cut


def get_hamiltonian_block(
    hr: NDArray,
    supercell_size: tuple,
    global_shift: tuple,
) -> xp.ndarray:
    """Assembles one block of a supercell Hamiltonian from an hr array.

    The block is a grid of unit-cell sub-blocks. The sub-block in row
    `r_i` and column `r_j` (both running over the cells of the
    supercell) is the `hr` entry at cell index
    `r_j - r_i + global_shift * supercell_size`. Cell indices beyond
    half the extent of `hr` along any axis contribute a zero sub-block.

    Parameters
    ----------
    hr : ndarray
        Wannier Hamiltonian.
    supercell_size : tuple
        Size of the supercell. E.g. (2, 2, 1) for a 2x2 xy-supercell.
    global_shift : tuple
        Shift given in units of supercells, NOT unit cells. E.g.
        (1, 0, 0) for a shift of one supercell in x direction.

    Returns
    -------
    ndarray
        The supercell hamiltonian block.

    """
    cell_offsets = xp.asarray(list(xp.ndindex(supercell_size)))
    # Arrays are needed for the element-wise product (cupy multiply).
    supercell_size = (
        supercell_size
        if isinstance(supercell_size, xp.ndarray)
        else xp.asarray(supercell_size)
    )
    global_shift = (
        global_shift
        if isinstance(global_shift, xp.ndarray)
        else xp.asarray(global_shift)
    )
    # Convert the shift from supercell units to unit-cell units.
    global_shift = xp.multiply(global_shift, supercell_size)

    def _sub_block(cell):
        """Returns the hr entry for `cell`, or zeros if it is out of range."""
        try:
            out_of_range = any(
                abs(component) > hr.shape[axis] // 2
                for axis, component in enumerate(cell)
            )
            if not out_of_range:
                return hr[cell]
        except IndexError:
            pass
        return xp.zeros(hr.shape[-2:], dtype=hr.dtype)

    return xp.vstack(
        [
            xp.hstack(
                [_sub_block(tuple(col - row + global_shift)) for col in cell_offsets]
            )
            for row in cell_offsets
        ]
    )


def create_coordinate_grid(
    wannier_centers: NDArray, super_cell: tuple, lattice_vectors: NDArray
) -> NDArray:
    """Creates a grid of coordinates for Wannier functions in a supercell.

    Parameters
    ----------
    wannier_centers : ndarray
        Coordinates of the Wannier functions, one row per function.
    super_cell : tuple
        Number of cells along each direction of the supercell.
    lattice_vectors : ndarray
        Lattice vectors, one row per vector.

    Returns
    -------
    ndarray
        Float64 array with three columns. The rows are grouped cell by
        cell (in `ndindex` order); each group holds the Wannier centers
        translated by that cell's index times the lattice vectors.
    """
    centers_per_cell = wannier_centers.shape[0]
    num_cells = int(xp.prod(xp.asarray(super_cell)))
    grid = xp.zeros((num_cells * centers_per_cell, 3), dtype=xp.float64)

    start = 0
    for cell in np.ndindex(super_cell):
        stop = start + centers_per_cell
        translation = xp.asarray(cell) @ lattice_vectors
        grid[start:stop, :] = wannier_centers + translation
        start = stop
    return grid


def create_hamiltonian(
    hR: NDArray,
    num_transport_cells: int,
    transport_dir: int | str = "x",
    transport_cell: list | tuple | None = None,
    block_start: int | None = None,
    block_end: int | None = None,
    return_sparse: bool = True,
    cutoff: float = xp.inf,
    coords: NDArray | None = None,
    lattice_vectors: NDArray | None = None,
) -> tuple:
    """Creates a block-tridiagonal Hamiltonian matrix from a Wannier Hamiltonian.
    The transport cell (same as supercell) is the cell that is repeated in the transport direction,
    and is only connected to nearest-neighboring cells. NOTE: interactions outside
    nearest neighbors are not included in the block-tridiagonal Hamiltonian (see below).
    It can therefore be important to make sure that the transport cell is large enough, such that
    each row have the same number of neighbouring cells. Not setting a transport cell will default
    to a cell that includes all interactions of hR.

      ------- -------
     | o o o | o o o | x
     | o o o | o o o | x x  <- cells outside nearest neighbors are not included
     | o o o | o o o | x x x
      ------- ------- -------
     | o o o | o o o | o o o |
     | o o o | o o o | o o o |
     | o o o | o o o | o o o |
      ------- ------- -------
       x x x | o o o | o o o |
         x x | o o o | o o o |
           x | o o o | o o o |
              ------- -------

    Parameters
    ----------
    hR : ndarray
        Wannier Hamiltonian.
    num_transport_cells : int
        Number of transport cells.
    transport_dir : int or str, optional
        Direction of transport. Can be 0, 1, 2, 'x', 'y', or 'z'.
    transport_cell : tuple or list, optional
        Size of the transport cell. E.g. [2, 2, 1] for a 2x2 xy-transport cell.
    block_start : int, optional
        Starting block index for arrow shape partition. Defaults to `None`,
        which means 0.
    block_end : int, optional
        Ending block index (exclusive) for arrow shape partition. Defaults
        to `None`, which means `num_transport_cells`.
    return_sparse : bool, optional
        Whether to return the block-tridiagonal Hamiltonian as a sparse matrix. Defaults to `True`.
    cutoff : float, optional
        Cutoff distance for connections between wannier functions. Only
        applied if both `coords` and `lattice_vectors` are given.
        Defaults to `xp.inf`.
    coords : ndarray, optional
        Coordinates of the Wannier functions in a unit cell. Defaults to `None`.
    lattice_vectors : ndarray, optional
        Lattice vectors of the unit cell. Defaults to `None`.

    Returns
    -------
    tuple
        If `return_sparse` is `True`, a tuple `(matrix, block_sizes)` of
        the sparse COO block-tridiagonal matrix and the array of block
        sizes. Otherwise a tuple `(diag, upper, lower)` of vertically
        tiled dense blocks.

    Raises
    ------
    ValueError
        If `transport_dir` is not a valid direction or if `block_start`
        and `block_end` do not describe a valid, non-empty block range.
    """
    cutoff_requested = cutoff is not None and cutoff < xp.inf
    has_geometry = coords is not None and lattice_vectors is not None
    if cutoff_requested and not has_geometry:
        # The cutoff needs both pieces of geometry; warn if either is missing.
        print(
            "Cutoff is set but coords and lattice_vectors are not provided. No cutoff will be applied.",
            flush=True,
        )

    if isinstance(transport_dir, str):
        # Tuple lookup only accepts exactly 'x', 'y' or 'z' and raises a
        # ValueError otherwise.
        transport_dir = ("x", "y", "z").index(transport_dir)
    if transport_dir not in (0, 1, 2):
        raise ValueError("transport_dir must be one of 0, 1, 2, 'x', 'y', or 'z'.")

    if transport_cell is None:
        # NOTE: Can also do without the + 1.
        transport_cell = tuple(
            shape // 2 + 1 if i == transport_dir else 1
            for i, shape in enumerate(hR.shape[:3])
        )

    if block_start is None:
        block_start = 0
    if block_end is None:
        block_end = num_transport_cells
    if block_start >= block_end:
        raise ValueError("block_start must be smaller than block_end.")
    if block_end > num_transport_cells:
        raise ValueError(
            "block_end must be smaller than or equal to num_transport_cells."
        )
    if block_start < 0:
        raise ValueError("block_start must be greater than or equal to 0.")

    upper_ind = tuple(1 if i == transport_dir else 0 for i in range(3))
    lower_ind = tuple(-1 if i == transport_dir else 0 for i in range(3))

    diag_block = get_hamiltonian_block(hR, transport_cell, (0, 0, 0))
    upper_block = get_hamiltonian_block(hR, transport_cell, upper_ind)
    lower_block = get_hamiltonian_block(hR, transport_cell, lower_ind)

    # Enforce cutoff.
    if cutoff_requested and has_geometry:
        super_cell_coords = create_coordinate_grid(
            coords, transport_cell, lattice_vectors
        )
        # displacement[i, j] = r_i - r_j within one transport cell.
        displacement = super_cell_coords[:, None, :] - super_cell_coords[None, :, :]
        # Translation between neighbouring transport cells: a whole
        # transport cell, i.e. `transport_cell[dir]` unit cells.
        cell_translation = (
            xp.asarray(upper_ind) * xp.asarray(transport_cell)
        ) @ lattice_vectors
        diag_dist = xp.linalg.norm(displacement, axis=-1)
        # Upper block: the column orbital sits one transport cell further
        # along the transport direction (r_j + T), lower block: r_j - T.
        upper_dist = xp.linalg.norm(displacement - cell_translation, axis=-1)
        lower_dist = xp.linalg.norm(displacement + cell_translation, axis=-1)
        diag_block[diag_dist > cutoff] = 0
        upper_block[upper_dist > cutoff] = 0
        lower_block[lower_dist > cutoff] = 0

    if return_sparse:
        # Create sparse matrices of the blocks.
        diag_block = sparse.coo_matrix(diag_block)
        upper_block = sparse.coo_matrix(upper_block)
        lower_block = sparse.coo_matrix(lower_block)
        # Canonicalize the sparse matrices.
        # NOTE: Not sure if this is necessary.
        for mat in (diag_block, upper_block, lower_block):
            if not mat.has_canonical_format:
                mat.sum_duplicates()
        # Create the block-tridiagonal matrix.
        block_size = diag_block.shape[0]
        num_blocks = block_end - block_start
        offsets = xp.arange(block_start, block_end) * block_size

        def _tile_sparse_blocks(block):
            """Repeats a COO block once per diagonal offset."""
            shifts = xp.repeat(offsets, block.nnz)
            return (
                xp.tile(block.row, num_blocks) + shifts,
                xp.tile(block.col, num_blocks) + shifts,
                xp.tile(block.data, num_blocks),
            )

        diag_rows, diag_cols, diag_data = _tile_sparse_blocks(diag_block)
        upper_rows, upper_cols, upper_data = _tile_sparse_blocks(upper_block)
        lower_rows, lower_cols, lower_data = _tile_sparse_blocks(lower_block)
        # Move the off-diagonal blocks one block to the right / down.
        upper_cols = upper_cols + block_size
        lower_rows = lower_rows + block_size

        full_rows = xp.hstack([diag_rows, upper_rows, lower_rows])
        full_cols = xp.hstack([diag_cols, upper_cols, lower_cols])
        full_data = xp.hstack([diag_data, upper_data, lower_data])
        # Remove the fishtail at the end of the matrix.
        matrix_shape = num_transport_cells * block_size
        valid_mask = (full_cols < matrix_shape) & (full_rows < matrix_shape)
        full_rows = full_rows[valid_mask]
        full_cols = full_cols[valid_mask]
        full_data = full_data[valid_mask]
        # Also return the block sizes.
        block_sizes = xp.ones(num_blocks, dtype=int) * block_size
        return (
            sparse.coo_matrix(
                (full_data, (full_rows, full_cols)),
                shape=(matrix_shape, matrix_shape),
            ),
            block_sizes,
        )

    # Returns the block-tridiagonal Hamiltonian matrix as a tuple of arrays.
    num_off_diag = min(block_end + 1, num_transport_cells) - (block_start + 1)
    diag = xp.tile(diag_block, (block_end - block_start, 1))
    upper = xp.tile(upper_block, (num_off_diag, 1))
    lower = xp.tile(lower_block, (num_off_diag, 1))

    return diag, upper, lower