# Source code for qcd_ml.util.solver

"""
qcd_ml.util.solver
==================

Solvers for systems of linear equations.
"""
import torch 
import numpy as np



def update_qr(H, s, c, j):
    """
    Runs and updates the QR decomposition of the matrix H.
    This function is used internally by GMRES_inner.

    First applies the previously computed Givens rotations ``(c[i], s[i])``
    for ``i < j`` to column ``j`` of ``H``. Then computes a new rotation
    ``(c[j], s[j])`` that zeros ``H[j+1, j]``. Each rotation acts on rows
    ``(i, i+1)`` as the unitary 2x2 matrix ``[[conj(c), conj(s)], [-s, c]]``.

    ``H``, ``s`` and ``c`` are modified in place; nothing is returned.
    """
    # Apply the previous Givens rotations to the new column of H.
    # The rotated lower entry is computed first because both updates need
    # the old values of H[i, j] and H[i + 1, j].
    for i in range(j):
        rotated_lower = -s[i] * H[i, j] + c[i] * H[i + 1, j]
        H[i, j] = np.conj(c[i]) * H[i, j] + np.conj(s[i]) * H[i + 1, j]
        H[i + 1, j] = rotated_lower

    # Compute the new Givens rotation. hypot avoids overflow/underflow
    # that squaring the magnitudes could cause.
    beta = np.hypot(np.abs(H[j, j]), np.abs(H[j + 1, j]))
    if beta == 0.0:
        # Both entries vanish, so there is nothing to eliminate. Use the
        # identity rotation instead of dividing 0/0, which would put NaNs
        # into s, c (and, through them, into the caller's gamma).
        c[j] = 1.0
        s[j] = 0.0
        return
    s[j] = H[j + 1, j] / beta
    c[j] = H[j, j] / beta
    H[j, j] = beta
    H[j + 1, j] = 0.0
def update_result(x, Z, gamma, H, y, j):
    """
    Updates the result of GMRES_inner by going from the Krylov space
    (spanned by Z, coefficients H and gamma) to the solution x.

    Solves the upper-triangular system ``H[:j+1, :j+1] y = gamma[:j+1]`` by
    back substitution. Only the upper triangle of ``H`` is read. The
    coefficients are written into ``y[:j+1]``, and the function returns
    ``x + sum_i y[i] * Z[i]``.

    ``x`` itself is NOT modified; a new vector is returned. If ``j < 0``,
    ``x`` is returned unchanged.
    """
    for i in reversed(range(j + 1)):
        y[i] = (gamma[i] - np.dot(H[i, i + 1:j + 1], y[i + 1:j + 1])) / H[i, i]
    # Accumulate out of place. An in-place ``+=`` would overwrite the
    # caller's vector (GMRES passes the user's initial guess here). It would
    # also fail for real-valued arrays/tensors, since the coefficients y are
    # complex128 and cannot be cast back in place.
    for i in range(j + 1):
        x = x + y[i] * Z[i]
    return x
def GMRES_inner(A, b, x0, stopat_residual, niterations, innerproduct, preconditioner):
    """
    Inner GMRES, i.e., ``niterations`` without restart.

    Starting from the residual ``b - A(x0)``, the function:

    - runs at most ``niterations`` Arnoldi steps (classical Gram-Schmidt);
    - keeps the Hessenberg matrix in triangular form with Givens rotations
      (``update_qr``);
    - returns the iterate built from the resulting Krylov directions
      (``update_result``).

    ``A``: callable applying the operator.
    ``b``: right-hand side.
    ``x0``: initial guess.
    ``stopat_residual``: stop as soon as the residual estimate
    ``|gamma[j+1]|`` drops below this value.
    ``niterations``: maximum number of Arnoldi steps.
    ``innerproduct``: inner product function.
    ``preconditioner``: ``None`` or a callable. If given, it is applied from
    the right: the operator acts on ``preconditioner(V[j])``, and the solution
    is expanded in those vectors.

    Returns ``(x, info)``. ``info`` has the following keys:

    - ``converged``, ``breakdown``: status flags;
    - ``res``: the last residual estimate;
    - ``k``: the number of Krylov directions used in ``x``;
    - ``target_residual``;
    - ``history``: an array of length ``niterations``. Entry ``j`` is the
      residual estimate after step ``j``; entries not reached stay zero.
    """
    r0 = b - A(x0)
    norm_r0 = float(np.abs(innerproduct(r0, r0))) ** 0.5
    history = np.zeros(niterations)

    # The initial guess already meets the target (or is exact). Returning
    # here also avoids normalizing a zero residual, which would give NaN.
    if norm_r0 == 0.0 or norm_r0 < stopat_residual:
        return x0, {"converged": True,
                    "breakdown": False,
                    "res": norm_r0,
                    "k": 0,
                    "target_residual": stopat_residual,
                    "history": history}

    V = [r0 / norm_r0] + [None] * niterations
    # Z[j] are the directions the solution is expanded in: preconditioned
    # basis vectors if a preconditioner is given, otherwise V itself.
    Z = [None] * niterations if preconditioner is not None else V

    H = np.zeros((niterations + 1, niterations), dtype=np.complex128)
    s = np.zeros(niterations + 1, dtype=np.complex128)
    c = np.zeros(niterations + 1, dtype=np.complex128)
    y = np.zeros(niterations + 1, dtype=np.complex128)
    gamma = np.zeros(niterations + 1, dtype=np.complex128)
    gamma[0] = norm_r0

    res = norm_r0
    k = 0  # number of Krylov directions that enter the solution
    breakdown = False
    converged = False
    for j in range(niterations):
        if preconditioner is not None:
            Z[j] = preconditioner(V[j])
        Avj = A(Z[j])

        # Arnoldi step: project out all previous basis vectors.
        for i in range(j + 1):
            H[i, j] = innerproduct(V[i], Avj)
        w = Avj
        for i in range(j + 1):
            w = w - H[i, j] * V[i]
        H[j + 1, j] = np.abs(innerproduct(w, w)) ** 0.5

        # H[j+1, j] == 0 is a breakdown: no new basis vector exists. Column j
        # must still be rotated into triangular form before it is used.
        # The new rotation then has s[j] == 0, so the residual estimate
        # becomes zero.
        breakdown = bool(H[j + 1, j] == 0.0)
        if not breakdown:
            V[j + 1] = w / H[j + 1, j]

        update_qr(H, s, c, j)
        if H[j, j] == 0.0:
            # Singular triangular factor (only possible on breakdown):
            # direction j cannot be used, so keep the previous iterate and
            # residual estimate.
            break
        k = j + 1

        gamma[j + 1] = -s[j] * gamma[j]
        gamma[j] = np.conj(c[j]) * gamma[j]
        res = np.abs(gamma[j + 1])
        history[j] = res
        if res < stopat_residual:
            converged = True
            break
        if breakdown:
            break

    x = update_result(x0, Z, gamma, H, y, k - 1)
    return x, {"converged": converged,
               "breakdown": breakdown,
               "res": res,
               "k": k,
               "target_residual": stopat_residual,
               "history": history}
def GMRES(A, b, x0,
          maxiter=1000,
          inner_iter=30,
          eps=1e-5,
          innerproduct=lambda x, y: (x.conj() * y).sum(),
          preconditioner=None,
          verbose=False):
    """
    Implementation of the restarted GMRES algorithm for solving the linear
    system Ax = b.

    ``A``: callable or a matrix that allows ``A @ x`` to be computed.
    ``b``: right-hand side of the linear system.
    ``x0``: initial guess for the solution.
    ``maxiter``: maximum total number of iterations (summed over restarts).
    ``inner_iter``: number of iterations before restarting.
    ``eps``: relative tolerance for the residual. The absolute tolerance is
    ``eps * ||b||``, or ``eps * ||b - A x0||`` if ``||b|| <= 1e-10``.
    ``innerproduct``: inner product function.
    ``preconditioner``: preconditioner function. Should be a function that
    takes a vector and returns a vector.
    ``verbose``: if true, print the residual after every restart cycle.

    Returns ``(x, info)``. ``info`` is the dictionary returned by the last
    ``GMRES_inner`` call (keys ``converged``, ``breakdown``, ``res`` and
    ``target_residual``). Its ``k`` is replaced by the total number of
    iterations, and its ``history`` by the residual estimates of all
    iterations.

    Raises ``ValueError`` if ``maxiter`` or ``inner_iter`` is smaller than 1,
    or if both ``b`` and ``b - A x0`` are zero (< 1e-10).
    """
    # Without at least one iteration per cycle the loop below cannot make
    # progress (previously this surfaced as an UnboundLocalError).
    if maxiter < 1 or inner_iter < 1:
        raise ValueError("maxiter and inner_iter must be at least 1")

    apply_A = A if callable(A) else (lambda x: A @ x)

    norm_b = np.abs(innerproduct(b, b)) ** 0.5
    if norm_b > 1e-10:
        stopat_residual = eps * norm_b
    else:
        # b is (numerically) zero, so a tolerance relative to ||b|| is
        # meaningless. Use the initial residual as the reference instead.
        # This is the only case that needs the extra operator application.
        r0 = b - apply_A(x0)
        norm_r0 = np.abs(innerproduct(r0, r0)) ** 0.5
        if norm_r0 < 1e-10:
            raise ValueError("b and A@x0 are zero (<1e-10)")
        stopat_residual = eps * norm_r0

    hist = np.zeros(maxiter)
    iters = 0
    x = x0
    while iters < maxiter:
        niters_this = min(inner_iter, maxiter - iters)
        x, info = GMRES_inner(apply_A, b, x, stopat_residual, niters_this,
                              innerproduct, preconditioner)
        k = info["k"]
        hist[iters:iters + k] = info["history"][:k]
        iters += k
        if verbose:
            print(f"GMRES: iter {iters}, res {info['res']}, target {info['target_residual']}")
        if info["converged"] or info["breakdown"]:
            break

    info["k"] = iters
    info["history"] = hist[:iters]
    return x, info