
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Solver and optimization algorithms.

This module provides a number of functions for solving linear systems and
optimization problems, some of which are used as subproblem solvers
within the iterations of the proximal algorithms in the
:mod:`scico.optimize` subpackage.

This module also provides scico interface wrappers for functions
from :mod:`scipy.optimize`, since jax directly implements only a very
limited subset of these functions: only CG and BFGS are fully
supported, and there is limited, experimental support for
`L-BFGS-B <https://github.com/google/jax/pull/6053>`_. These wrappers
are required because the functions in :mod:`scipy.optimize` only
support 1D, real valued, numpy arrays. These limitations are
addressed by:

- Enabling the use of multi-dimensional arrays by flattening and reshaping
  within the wrapper.
- Enabling the use of jax arrays by automatically converting to and from
  numpy arrays.
- Enabling the use of complex arrays by splitting them into real and
  imaginary parts.

The wrapper also JIT compiles the function and gradient evaluations.

These wrapper functions have a number of advantages and disadvantages
with respect to those in :mod:`jax.scipy.optimize`:

- This module provides many more algorithms than
  :mod:`jax.scipy.optimize`.
- The functions in this module tend to be faster for small-scale problems
  (presumably due to some overhead in the jax functions).
- The functions in this module are slower for large problems due to the
  frequent host-device copies corresponding to conversion between numpy
  arrays and jax arrays.
- The solvers in this module can't be JIT compiled, and gradients cannot
  be taken through them.

In the future, these wrapper functions may be replaced with a dependency on
`JAXopt <https://github.com/google/jaxopt>`__.
"""


from functools import wraps
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import numpy as np

import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

import scico.numpy as snp
from scico.linop import (
    CircularConvolve,
    ComposedLinearOperator,
    Diagonal,
    LinearOperator,
    MatrixOperator,
    Sum,
)
from scico.metric import rel_res
from scico.numpy import Array, BlockArray
from scico.numpy.util import is_real_dtype
from scico.typing import BlockShape, DType, Shape
from scipy import optimize as spopt


def _wrap_func(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable:
    """Function evaluation for use in :mod:`scipy.optimize`.

    Compute function evaluation (without gradient) for use in
    :mod:`scipy.optimize` functions. Reshapes the input to `func` to
    have `shape`. Evaluates `func`.

    Args:
        func: The function to minimize.
        shape: Shape of input to `func`.
        dtype: Data type of input to `func`.
    """

    val_func = jax.jit(func)

    @wraps(func)
    def wrapper(x, *args):
        # apply val_func to un-vectorized input
        val = val_func(snp.reshape(x, shape).astype(dtype), *args)

        # Convert val into numpy array, cast to float, convert to scalar
        val = np.array(val).astype(float)
        val = val.item() if val.ndim == 0 else val[0].item()

        return val

    return wrapper


def _wrap_func_and_grad(func: Callable, shape: Union[Shape, BlockShape], dtype: DType) -> Callable:
    """Function evaluation and gradient for use in :mod:`scipy.optimize`.

    Compute function evaluation and gradient for use in
    :mod:`scipy.optimize` functions. Reshapes the input to `func` to
    have `shape`.  Evaluates `func` and computes gradient. Ensures
    the returned `grad` is an ndarray.

    Args:
        func: The function to minimize.
        shape: Shape of input to `func`.
        dtype: Data type of input to `func`.
    """

    # argnums=0 ensures that we only differentiate func with respect to its
    #   first argument, in case the signature of func is func(x, *args)
    val_grad_func = jax.jit(jax.value_and_grad(func, argnums=0))

    @wraps(func)
    def wrapper(x, *args):
        # apply val_grad_func to un-vectorized input
        val, grad = val_grad_func(snp.reshape(x, shape).astype(dtype), *args)

        # Convert val & grad into numpy arrays, then cast to float
        # Convert 'val' into a scalar, rather than ndarray of shape (1,)
        val = np.array(val).astype(float).item()
        grad = np.array(grad).astype(float).ravel()
        return val, grad

    return wrapper
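

# Illustrative sketch (not part of the SCICO API): how the wrapper above
# adapts a jax function to the scipy.optimize calling convention. The
# function `f` and all values here are hypothetical.
def _wrap_func_and_grad_example() -> None:
    f = lambda x: snp.sum(x**2)
    fg = _wrap_func_and_grad(f, (2, 2), np.float32)
    # scipy.optimize passes a flat numpy array; the wrapper reshapes it
    # to (2, 2), evaluates f and its gradient via jax, and returns a
    # (float, flat numpy array) pair
    val, grad = fg(np.ones(4))
    assert val == 4.0 and grad.shape == (4,)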


def _split_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
    """Split an array of shape (N, M, ...) into real and imaginary parts.

    Args:
        x: Array to split.

    Returns:
        A real ndarray with stacked real/imaginary parts. If `x` has
        shape (M, N, ...), the returned array will have shape
        (2, M, N, ...), where the first slice contains `x.real` and
        the second contains `x.imag`. If `x` is a BlockArray, this
        function is called on each block and the output is joined into a
        BlockArray.
    """
    if isinstance(x, BlockArray):
        return snp.blockarray([_split_real_imag(_) for _ in x])
    return snp.stack((snp.real(x), snp.imag(x)))


def _join_real_imag(x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
    """Join a real array of shape (2,N,M,...) into a complex array.

    Join a real array of shape (2, N, M, ...) into a complex array of
    shape (N, M, ...).

    Args:
        x: Array to join.

    Returns:
        A complex array with real and imaginary parts taken from `x[0]`
        and `x[1]` respectively.
    """
    if isinstance(x, BlockArray):
        return snp.blockarray([_join_real_imag(_) for _ in x])
    return x[0] + 1j * x[1]
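

# Illustrative sketch (not part of the SCICO API): round trip through the
# two helpers above. The array values are hypothetical.
def _split_join_example() -> None:
    x = snp.array([1.0 + 2.0j, 3.0 - 4.0j])
    s = _split_real_imag(x)  # real array of shape (2, 2): [x.real, x.imag]
    y = _join_real_imag(s)  # complex array equal to x
    assert s.shape == (2, 2)
    assert np.allclose(np.array(x), np.array(y))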


def minimize(
    func: Callable,
    x0: Union[Array, BlockArray],
    args: Union[Tuple, Tuple[Any]] = (),
    method: str = "L-BFGS-B",
    hess: Optional[Union[Callable, str]] = None,
    hessp: Optional[Callable] = None,
    bounds: Optional[Union[Sequence, spopt.Bounds]] = None,
    constraints: Union[spopt.LinearConstraint, spopt.NonlinearConstraint, dict] = (),
    tol: Optional[float] = None,
    callback: Optional[Callable] = None,
    options: Optional[dict] = None,
) -> spopt.OptimizeResult:
    """Minimization of scalar function of one or more variables.

    Wrapper around :func:`scipy.optimize.minimize`. This function
    differs from :func:`scipy.optimize.minimize` in three ways:

    - The `jac` options of :func:`scipy.optimize.minimize` are not
      supported. The gradient is calculated using `jax.grad`.
    - Functions mapping from N-dimensional arrays -> float are
      supported.
    - Functions mapping from complex arrays -> float are supported.

    For more detail, including descriptions of the optimization methods
    and custom minimizers, refer to the original docs for
    :func:`scipy.optimize.minimize`.
    """
    if snp.util.is_complex_dtype(x0.dtype):
        # scipy minimize function requires real-valued arrays, so
        # we split x0 into a vector with real/imaginary parts stacked
        # and compose `func` with `_join_real_imag`
        iscomplex = True
        func_ = lambda x: func(_join_real_imag(x))
        x0 = _split_real_imag(x0)
    else:
        iscomplex = False
        func_ = func

    x0_shape = x0.shape
    x0_dtype = x0.dtype
    x0 = x0.ravel()  # if x0 is a BlockArray it will become a jax array here

    # Run the SciPy minimizer
    if method in (
        "CG, BFGS, Newton-CG, L-BFGS-B, TNC, SLSQP, dogleg, trust-ncg, trust-krylov, "
        "trust-exact, trust-constr"
    ).split(
        ", "
    ):  # uses gradient info
        min_func = _wrap_func_and_grad(func_, x0_shape, x0_dtype)
        jac = True  # see scipy.minimize docs
    else:  # does not use gradient info
        min_func = _wrap_func(func_, x0_shape, x0_dtype)
        jac = False

    res = spopt.OptimizeResult({"x": None})

    def fun(x0):
        nonlocal res  # To use the external res and update side effect
        res = spopt.minimize(
            min_func,
            x0=x0,
            args=args,
            jac=jac,
            method=method,
            options=options,
        )
        # Return OptimizeResult with x0 as ndarray
        return res.x.astype(x0_dtype)

    # callback with side effects to get the OptimizeResult on the same device it was called
    res.x = jax.pure_callback(
        fun,
        jax.ShapeDtypeStruct(x0.shape, x0_dtype),
        x0,
    )

    # un-vectorize the output array from spopt.minimize
    res.x = snp.reshape(
        res.x, x0_shape
    )  # if x0 was originally a BlockArray then res.x is converted back to one here

    if iscomplex:
        res.x = _join_real_imag(res.x)

    return res
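

# Illustrative usage sketch (not part of the SCICO API): minimizing a
# function of a complex 2D array, exercising the flattening and
# real/imaginary splitting described in the module docstring. The target
# array `c` is a hypothetical value.
def _minimize_example() -> None:
    c = snp.full((2, 3), 1.0 + 1.0j)
    f = lambda x: snp.sum(snp.abs(x - c) ** 2)  # real-valued loss
    x0 = snp.zeros((2, 3), dtype=snp.complex64)
    res = minimize(f, x0, method="L-BFGS-B")
    # res.x is a complex (2, 3) array approximately equal to c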


def minimize_scalar(
    func: Callable,
    bracket: Optional[Sequence[float]] = None,
    bounds: Optional[Sequence[float]] = None,
    args: Union[Tuple, Tuple[Any]] = (),
    method: str = "brent",
    tol: Optional[float] = None,
    options: Optional[dict] = None,
) -> spopt.OptimizeResult:
    """Minimization of scalar function of one variable.

    Wrapper around :func:`scipy.optimize.minimize_scalar`.

    For more detail, including descriptions of the optimization methods
    and custom minimizers, refer to the original docstring for
    :func:`scipy.optimize.minimize_scalar`.
    """

    def f(x, *args):
        # Wrap jax-based function `func` to return a numpy float rather
        # than a jax array of size (1,)
        y = func(x, *args)
        return y.item() if y.ndim == 0 else y[0].item()

    res = spopt.minimize_scalar(
        fun=f,
        bracket=bracket,
        bounds=bounds,
        args=args,
        method=method,
        tol=tol,
        options=options,
    )
    return res
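

# Illustrative usage sketch (not part of the SCICO API): scalar
# minimization of a simple quadratic. The function and bracket are
# hypothetical; note that `func` must return a jax array (e.g. by using
# scico.numpy operations) so that the wrapper can call `.item()` on it.
def _minimize_scalar_example() -> None:
    f = lambda x: snp.square(x - 1.5)
    res = minimize_scalar(f, bracket=(0.0, 3.0))
    # res.x is approximately 1.5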


def cg(
    A: Callable,
    b: Array,
    x0: Optional[Array] = None,
    *,
    tol: float = 1e-5,
    atol: float = 0.0,
    maxiter: int = 1000,
    info: bool = True,
    M: Optional[Callable] = None,
) -> Tuple[Array, dict]:
    r"""Conjugate Gradient solver.

    Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is
    positive definite, via the conjugate gradient method.

    Args:
        A: Callable implementing linear operator :math:`A`, which should
            be positive definite.
        b: Input array :math:`\mb{b}`.
        x0: Initial solution. If `A` is a :class:`.LinearOperator`, this
            parameter need not be specified, and defaults to a zero
            array. Otherwise, it is required.
        tol: Relative residual stopping tolerance. Convergence occurs
            when `norm(residual) <= max(tol * norm(b), atol)`.
        atol: Absolute residual stopping tolerance. Convergence occurs
            when `norm(residual) <= max(tol * norm(b), atol)`.
        maxiter: Maximum iterations. Default: 1000.
        info: If ``True`` return a tuple consisting of the solution
            array and a dictionary containing diagnostic information,
            otherwise just return the solution.
        M: Preconditioner for `A`. The preconditioner should approximate
            the inverse of `A`. The default, ``None``, uses no
            preconditioner.

    Returns:
        tuple: A tuple (x, info) containing:

            - **x** : Solution array.
            - **info**: Dictionary containing diagnostic information.
    """
    if x0 is None:
        if isinstance(A, LinearOperator):
            x0 = snp.zeros(A.input_shape, b.dtype)
        else:
            raise ValueError("Parameter x0 must be specified if A is not a LinearOperator")

    if M is None:
        M = lambda x: x

    x = x0
    Ax = A(x0)
    bn = snp.linalg.norm(b)
    r = b - Ax
    z = M(r)
    p = z
    num = snp.sum(r.conj() * z)
    ii = 0

    # termination tolerance (uses the "non-legacy" form of scipy.sparse.linalg.cg)
    termination_tol_sq = snp.maximum(tol * bn, atol) ** 2

    while (ii < maxiter) and (num > termination_tol_sq):
        Ap = A(p)
        alpha = num / snp.sum(p.conj() * Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        z = M(r)
        num_old = num
        num = snp.sum(r.conj() * z)
        beta = num / num_old
        p = z + beta * p
        ii += 1

    if info:
        return (x, {"num_iter": ii, "rel_res": snp.sqrt(num).real / bn})
    else:
        return x
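

# Illustrative usage sketch (not part of the SCICO API): solving a small
# symmetric positive definite system with `cg`. The matrix is built from
# hypothetical random data; `x0` is required here since `A` is a plain
# callable rather than a LinearOperator.
def _cg_example() -> None:
    rng = np.random.default_rng(0)
    B = rng.standard_normal((8, 8)).astype(np.float32)
    Amx = snp.array(B.T @ B + 0.1 * np.eye(8, dtype=np.float32))  # SPD matrix
    b = snp.array(rng.standard_normal(8).astype(np.float32))
    A = lambda x: Amx @ x
    x, info = cg(A, b, x0=snp.zeros(8, b.dtype), tol=1e-6)
    # info["rel_res"] holds the final relative residual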


def lstsq(
    A: Callable,
    b: Array,
    x0: Optional[Array] = None,
    tol: float = 1e-5,
    atol: float = 0.0,
    maxiter: int = 1000,
    info: bool = False,
    M: Optional[Callable] = None,
) -> Tuple[Array, dict]:
    r"""Least squares solver.

    Solve the least squares problem

    .. math::
        \argmin_{\mb{x}} \; (1/2) \norm{ A \mb{x} - \mb{b} }_2^2 \;,

    where :math:`A` is a linear operator and :math:`\mb{b}` is a vector.
    The problem is solved using :func:`cg`.

    Args:
        A: Callable implementing linear operator :math:`A`.
        b: Input array :math:`\mb{b}`.
        x0: Initial solution. If `A` is a :class:`.LinearOperator`, this
            parameter need not be specified, and defaults to a zero
            array. Otherwise, it is required.
        tol: Relative residual stopping tolerance. Convergence occurs
            when `norm(residual) <= max(tol * norm(b), atol)`.
        atol: Absolute residual stopping tolerance. Convergence occurs
            when `norm(residual) <= max(tol * norm(b), atol)`.
        maxiter: Maximum iterations. Default: 1000.
        info: If ``True`` return a tuple consisting of the solution
            array and a dictionary containing diagnostic information,
            otherwise just return the solution.
        M: Preconditioner for `A`. The preconditioner should approximate
            the inverse of `A`. The default, ``None``, uses no
            preconditioner.

    Returns:
        tuple: A tuple (x, info) containing:

            - **x** : Solution array.
            - **info**: Dictionary containing diagnostic information.
    """
    if isinstance(A, LinearOperator):
        Aop = A
    else:
        assert x0 is not None
        Aop = LinearOperator(
            input_shape=x0.shape,
            output_shape=b.shape,
            eval_fn=A,
            input_dtype=b.dtype,
            output_dtype=b.dtype,
        )

    ATA = Aop.T @ Aop
    ATb = Aop.T @ b
    return cg(ATA, ATb, x0=x0, tol=tol, atol=atol, maxiter=maxiter, info=info, M=M)
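

# Illustrative usage sketch (not part of the SCICO API): solving an
# overdetermined least squares problem via the normal equations. The
# matrix and data are hypothetical; `x0` supplies the input shape needed
# to build a LinearOperator from the plain callable `A`.
def _lstsq_example() -> None:
    rng = np.random.default_rng(0)
    Amx = snp.array(rng.standard_normal((16, 8)).astype(np.float32))
    b = snp.array(rng.standard_normal(16).astype(np.float32))
    A = lambda x: Amx @ x
    x = lstsq(A, b, x0=snp.zeros(8, b.dtype))
    # x approximately minimizes ||A x - b||_2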


def bisect(
    f: Callable,
    a: Array,
    b: Array,
    args: Tuple = (),
    xtol: float = 1e-7,
    ftol: float = 1e-7,
    maxiter: int = 100,
    full_output: bool = False,
    range_check: bool = True,
) -> Union[Array, Tuple[Array, dict]]:
    """Vectorised root finding via bisection method.

    Vectorised root finding via bisection method, supporting
    simultaneous finding of multiple roots on a function defined over a
    multi-dimensional array. When the function is array-valued, each of
    these values is treated as the independent application of a scalar
    function. The initial interval `[a, b]` must bracket the root for
    all scalar functions.

    The interface is similar to that of :func:`scipy.optimize.bisect`,
    which is much faster when `f` is a scalar function and `a` and `b`
    are scalars.

    Args:
        f: Function returning a float or an array of floats.
        a: Lower bound of interval on which to apply bisection.
        b: Upper bound of interval on which to apply bisection.
        args: Additional arguments for function `f`.
        xtol: Stopping tolerance based on maximum bisection interval
            length over array.
        ftol: Stopping tolerance based on maximum absolute function
            value over array.
        maxiter: Maximum number of algorithm iterations.
        full_output: If ``False``, return just the root, otherwise
            return a tuple `(x, info)` where `x` is the root and `info`
            is a dict containing algorithm status information.
        range_check: If ``True``, check to ensure that the initial
            `[a, b]` range brackets the root of `f`.

    Returns:
        tuple: A tuple `(x, info)` containing:

            - **x** : Root array.
            - **info**: Dictionary containing diagnostic information.
    """
    fa = f(*((a,) + args))
    fb = f(*((b,) + args))
    if range_check and snp.any(snp.sign(fa) == snp.sign(fb)):
        raise ValueError("Initial bisection range does not bracket zero.")

    for numiter in range(maxiter):
        c = (a + b) / 2.0
        fc = f(*((c,) + args))
        fcs = snp.sign(fc)
        a = snp.where(snp.logical_or(snp.sign(fa) * fcs == 1, fc == 0.0), c, a)
        b = snp.where(snp.logical_or(fcs * snp.sign(fb) == 1, fc == 0.0), c, b)
        fa = f(*((a,) + args))
        fb = f(*((b,) + args))
        xerr = snp.max(snp.abs(b - a))
        ferr = snp.max(snp.abs(fc))
        if xerr <= xtol and ferr <= ftol:
            break

    idx = snp.argmin(snp.stack((snp.abs(fa), snp.abs(fb))), axis=0)
    x = snp.choose(idx, (a, b))
    if full_output:
        r = x, {"iter": numiter, "xerr": xerr, "ferr": ferr, "a": a, "b": b}
    else:
        r = x
    return r
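

# Illustrative usage sketch (not part of the SCICO API): finding the
# square roots of an array of values simultaneously. The function and
# interval are hypothetical; [0, 3] brackets the root of x**2 - c for
# every c in (0, 9).
def _bisect_example() -> None:
    c = snp.array([1.0, 2.0, 4.0])
    f = lambda x: x**2 - c
    x = bisect(f, snp.zeros(3), snp.full((3,), 3.0))
    # x is approximately [1.0, sqrt(2), 2.0]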


def golden(
    f: Callable,
    a: Array,
    b: Array,
    c: Optional[Array] = None,
    args: Tuple = (),
    xtol: float = 1e-7,
    maxiter: int = 100,
    full_output: bool = False,
) -> Union[Array, Tuple[Array, dict]]:
    """Vectorised scalar minimization via golden section method.

    Vectorised scalar minimization via golden section method,
    supporting simultaneous minimization of a function defined over a
    multi-dimensional array. When the function is array-valued, each of
    these values is treated as the independent application of a scalar
    function. The minimizer must lie within the interval `(a, b)` for
    all scalar functions, and, if specified, `c` must be within that
    interval.

    The interface is more similar to that of :func:`.bisect` than that
    of :func:`scipy.optimize.golden`, which is much faster when `f` is
    a scalar function and `a`, `b`, and `c` are scalars.

    Args:
        f: Function returning a float or an array of floats.
        a: Lower bound of interval on which to search.
        b: Upper bound of interval on which to search.
        c: Initial value for first search point interior to bounding
            interval `(a, b)`.
        args: Additional arguments for function `f`.
        xtol: Stopping tolerance based on maximum search interval
            length over array.
        maxiter: Maximum number of algorithm iterations.
        full_output: If ``False``, return just the minimizer, otherwise
            return a tuple `(x, info)` where `x` is the minimizer and
            `info` is a dict containing algorithm status information.

    Returns:
        tuple: A tuple `(x, info)` containing:

            - **x** : Minimizer array.
            - **info**: Dictionary containing diagnostic information.
    """
    gr = 2 / (snp.sqrt(5) + 1)
    if c is None:
        c = b - gr * (b - a)
    d = a + gr * (b - a)

    for numiter in range(maxiter):
        fc = f(*((c,) + args))
        fd = f(*((d,) + args))
        b = snp.where(fc < fd, d, b)
        a = snp.where(fc >= fd, c, a)
        xerr = snp.amax(snp.abs(b - a))
        if xerr <= xtol:
            break
        c = b - gr * (b - a)
        d = a + gr * (b - a)

    fa = f(*((a,) + args))
    fb = f(*((b,) + args))
    idx = snp.argmin(snp.stack((fa, fb)), axis=0)
    x = snp.choose(idx, (a, b))
    if full_output:
        r = (x, {"iter": numiter, "xerr": xerr})
    else:
        r = x
    return r
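

# Illustrative usage sketch (not part of the SCICO API): minimizing a
# separable quadratic over each entry of an array. The target values `t`
# and the search interval are hypothetical; each minimizer lies within
# (0, 1).
def _golden_example() -> None:
    t = snp.array([0.2, 0.5, 0.8])
    f = lambda x: (x - t) ** 2
    x = golden(f, snp.zeros(3), snp.ones(3))
    # x is approximately equal to t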


class MatrixATADSolver:
    r"""Solver for linear system involving a symmetric product.

    Solve a linear system of the form

    .. math::

       (A^T W A + D) \mb{x} = \mb{b}

    or

    .. math::

       (A^T W A + D) X = B \;,

    where :math:`A \in \mbb{R}^{M \times N}`, :math:`W \in
    \mbb{R}^{M \times M}` and :math:`D \in \mbb{R}^{N \times N}`.
    :math:`A` must be an instance of :class:`.MatrixOperator` or an
    array; :math:`D` must be an instance of :class:`.MatrixOperator`,
    :class:`.Diagonal`, or an array, and :math:`W`, if specified, must
    be an instance of :class:`.Diagonal` or an array.

    The solution is computed by factorization of matrix
    :math:`A^T W A + D` and solution via Gaussian elimination. If
    :math:`D` is diagonal and :math:`M < N` (i.e. :math:`A A^T` is
    smaller than :math:`A^T A`), then a smaller matrix of the shape of
    :math:`A A^T` is factorized instead, and the original problem is
    solved via the Woodbury matrix identity

    .. math::

       (E + U C V)^{-1} = E^{-1} - E^{-1} U (C^{-1} + V E^{-1} U)^{-1}
       V E^{-1} \;.

    Setting

    .. math::

       E &= D \\
       U &= A^T \\
       C &= W \\
       V &= A

    we have

    .. math::

       (D + A^T W A)^{-1} = D^{-1} - D^{-1} A^T
       (W^{-1} + A D^{-1} A^T)^{-1} A D^{-1}

    which can be simplified to

    .. math::

       (D + A^T W A)^{-1} = D^{-1} (I - A^T G^{-1} A D^{-1})

    by defining :math:`G = W^{-1} + A D^{-1} A^T`. We therefore have
    that

    .. math::

       \mb{x} = (D + A^T W A)^{-1} \mb{b} =
       D^{-1} (I - A^T G^{-1} A D^{-1}) \mb{b} \;.

    If we have a Cholesky factorization of :math:`G`, e.g.
    :math:`G = L L^T`, we can define

    .. math::

       \mb{w} = G^{-1} A D^{-1} \mb{b}

    so that

    .. math::

       G \mb{w} &= A D^{-1} \mb{b} \\
       L L^T \mb{w} &= A D^{-1} \mb{b} \;.

    The Cholesky factorization can be exploited by solving for
    :math:`\mb{z}` in

    .. math::

       L \mb{z} = A D^{-1} \mb{b}

    and then for :math:`\mb{w}` in

    .. math::

       L^T \mb{w} = \mb{z} \;,

    so that

    .. math::

       \mb{x} = D^{-1} \mb{b} - D^{-1} A^T \mb{w} \;.

    (Functions :func:`~jax.scipy.linalg.cho_solve` and
    :func:`~jax.scipy.linalg.lu_solve` allow direct solution for
    :math:`\mb{w}` without the two-step procedure described here.)
    A Cholesky factorization should only be used when :math:`G` is
    positive-definite (e.g. :math:`D` is diagonal and positive); if
    not, an LU factorization should be used.

    Complex-valued problems are also supported, in which case the
    transpose :math:`\cdot^T` in the equations above should be taken to
    represent the conjugate transpose. To solve problems directly
    involving a matrix of the form :math:`A W A^T + D`, initialize with
    :code:`A.T` (or :code:`A.T.conj()` for complex problems) instead of
    :code:`A`.
    """

    def __init__(
        self,
        A: Union[MatrixOperator, Array],
        D: Union[MatrixOperator, Diagonal, Array],
        W: Optional[Union[Diagonal, Array]] = None,
        cho_factor: bool = False,
        lower: bool = False,
        check_finite: bool = True,
    ):
        r"""
        Args:
            A: Matrix :math:`A`.
            D: Matrix :math:`D`. If a 2D array or
               :class:`MatrixOperator`, specifies the 2D matrix
               :math:`D`. If 1D array or :class:`Diagonal`, specifies
               the diagonal elements of :math:`D`.
            W: Matrix :math:`W`. Specifies the diagonal elements of
               :math:`W`. Defaults to an array with unit entries.
            cho_factor: Flag indicating whether to use Cholesky
               (``True``) or LU (``False``) factorization.
            lower: Flag indicating whether lower (``True``) or upper
               (``False``) triangular factorization should be computed.
               Only relevant to Cholesky factorization.
            check_finite: Flag indicating whether the input array should
               be checked for ``Inf`` and ``NaN`` values.
        """
        A = jnp.array(A)
        if isinstance(D, Diagonal):
            D = D.diagonal
            if D.ndim != 1:
                raise ValueError("If Diagonal, D should have a 1D diagonal.")
        else:
            D = jnp.array(D)
            if D.ndim not in [1, 2]:
                raise ValueError("If array or MatrixOperator, D should be 1D or 2D.")
        if W is None:
            W = snp.ones(A.shape[0], dtype=A.dtype)
        elif isinstance(W, Diagonal):
            W = W.diagonal
            if W.ndim != 1:
                raise ValueError("If Diagonal, W should have a 1D diagonal.")
        elif not isinstance(W, Array):
            raise TypeError(
                f"Operator W is required to be None, a Diagonal, or an array; got a {type(W)}."
            )
        self.A = A
        self.D = D
        self.W = W
        self.cho_factor = cho_factor
        self.lower = lower
        self.check_finite = check_finite

        assert isinstance(W, Array)
        N, M = A.shape
        if N < M and D.ndim == 1:
            G = snp.diag(1.0 / W) + A @ (A.T.conj() / D[:, snp.newaxis])
        else:
            if D.ndim == 1:
                G = A.T.conj() @ (W[:, snp.newaxis] * A) + snp.diag(D)
            else:
                G = A.T.conj() @ (W[:, snp.newaxis] * A) + D

        if cho_factor:
            c, lower = jsl.cho_factor(G, lower=lower, check_finite=check_finite)
            self.factor = (c, lower)
        else:
            lu, piv = jsl.lu_factor(G, check_finite=check_finite)
            self.factor = (lu, piv)

    def solve(self, b: Array, check_finite: Optional[bool] = None) -> Array:
        r"""Solve the linear system.

        Solve the linear system with right hand side :math:`\mb{b}`
        (`b` is a vector) or :math:`B` (`b` is a 2d array).

        Args:
            b: Vector :math:`\mathbf{b}` or matrix :math:`B`.
            check_finite: Flag indicating whether the input array should
               be checked for ``Inf`` and ``NaN`` values. If ``None``,
               use the value selected on initialization.

        Returns:
            Solution to the linear system.
        """
        if check_finite is None:
            check_finite = self.check_finite
        if self.cho_factor:
            fact_solve = lambda x: jsl.cho_solve(self.factor, x, check_finite=check_finite)
        else:
            fact_solve = lambda x: jsl.lu_solve(self.factor, x, trans=0, check_finite=check_finite)

        if b.ndim == 1:
            D = self.D
        else:
            D = self.D[:, snp.newaxis]

        N, M = self.A.shape
        if N < M and self.D.ndim == 1:
            w = fact_solve(self.A @ (b / D))
            x = (b - (self.A.T.conj() @ w)) / D
        else:
            x = fact_solve(b)

        return x

    def accuracy(self, x: Array, b: Array) -> float:
        r"""Compute solution relative residual.

        Args:
            x: Array :math:`\mathbf{x}` (solution).
            b: Array :math:`\mathbf{b}` (right hand side of linear
               system).

        Returns:
            Relative residual of solution.
        """
        if b.ndim == 1:
            D = self.D
        else:
            D = self.D[:, snp.newaxis]
        assert isinstance(self.W, Array)
        return rel_res(self.A.T.conj() @ (self.W[:, snp.newaxis] * self.A) @ x + D * x, b)
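

# Illustrative usage sketch (not part of the SCICO API): solving
# (A^T A + D) x = b with a diagonal D. The matrix and data are
# hypothetical; W defaults to a vector of ones.
def _matrix_atad_example() -> None:
    rng = np.random.default_rng(0)
    A = snp.array(rng.standard_normal((16, 8)).astype(np.float32))
    D = snp.full((8,), 0.1, dtype=np.float32)  # diagonal of D
    b = snp.array(rng.standard_normal(8).astype(np.float32))
    solver = MatrixATADSolver(A, D)  # LU factorization by default
    x = solver.solve(b)
    err = solver.accuracy(x, b)  # relative residual of the solution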


class ConvATADSolver:
    r"""Solver for a linear system involving a sum of convolutions.

    Solve a linear system of the form

    .. math::

       (A^H A + D) \mb{x} = \mb{b}

    where :math:`A` is a block-row operator with circulant blocks, i.e.
    it can be written as

    .. math::

       A = \left( \begin{array}{cccc} A_1 & A_2 & \ldots & A_{K}
       \end{array} \right) \;,

    where all of the :math:`A_k` are circular convolution operators,
    and :math:`D` is a circular convolution operator. This problem is
    most easily solved in the DFT transform domain, where the circular
    convolutions become diagonal operators. Denoting the
    frequency-domain versions of variables with a circumflex (e.g.
    :math:`\hat{\mb{x}}` is the frequency-domain version of
    :math:`\mb{x}`), the problem can be written as

    .. math::

       (\hat{A}^H \hat{A} + \hat{D}) \hat{\mb{x}} = \hat{\mb{b}} \;,

    where

    .. math::

       \hat{A} = \left( \begin{array}{cccc} \hat{A}_1 & \hat{A}_2 &
       \ldots & \hat{A}_{K} \end{array} \right) \;,

    and :math:`\hat{D}` and all the :math:`\hat{A}_k` are diagonal
    operators.

    This linear equation is computationally expensive to solve because
    the left hand side includes the term :math:`\hat{A}^H \hat{A}`,
    which corresponds to the outer product of :math:`\hat{A}^H` and
    :math:`\hat{A}`. A computationally efficient solution is possible,
    however, by exploiting the Woodbury matrix identity
    :cite:`wohlberg-2014-efficient`

    .. math::

       (B + U C V)^{-1} = B^{-1} - B^{-1} U (C^{-1} + V B^{-1} U)^{-1}
       V B^{-1} \;.

    Setting

    .. math::

       B &= \hat{D} \\
       U &= \hat{A}^H \\
       C &= I \\
       V &= \hat{A}

    we have

    .. math::

       (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1} -
       \hat{D}^{-1} \hat{A}^H (I + \hat{A} \hat{D}^{-1}
       \hat{A}^H)^{-1} \hat{A} \hat{D}^{-1}

    which can be simplified to

    .. math::

       (\hat{D} + \hat{A}^H \hat{A})^{-1} = \hat{D}^{-1}
       (I - \hat{A}^H \hat{E}^{-1} \hat{A} \hat{D}^{-1})

    by defining :math:`\hat{E} = I + \hat{A} \hat{D}^{-1} \hat{A}^H`.
    The right hand side is much cheaper to compute because the only
    matrix inversions involve :math:`\hat{D}`, which is diagonal, and
    :math:`\hat{E}`, which is a weighted inner product of
    :math:`\hat{A}^H` and :math:`\hat{A}`.
    """

    def __init__(self, A: ComposedLinearOperator, D: CircularConvolve):
        r"""
        Args:
            A: Operator :math:`A`.
            D: Operator :math:`D`.
        """
        if not isinstance(A, ComposedLinearOperator):
            raise TypeError(
                f"Operator A is required to be a ComposedLinearOperator; got a {type(A)}."
            )
        if not isinstance(A.A, Sum) or not isinstance(A.B, CircularConvolve):
            raise TypeError(
                "Operator A is required to be a composition of Sum and CircularConvolve "
                f"linear operators; got a composition of {type(A.A)} and {type(A.B)}."
            )

        self.A = A
        self.D = D
        self.sum_axis = A.A.kwargs["axis"]
        if not isinstance(self.sum_axis, int):
            raise ValueError(
                "Sum component of operator A must sum over a single axis of its input."
            )
        self.fft_axes = A.B.x_fft_axes
        self.real_result = is_real_dtype(D.input_dtype)

        Ahat = A.B.h_dft
        Dhat = D.h_dft
        self.AHEinv = Ahat.conj() / (
            1.0 + snp.sum(Ahat * (Ahat.conj() / Dhat), axis=self.sum_axis, keepdims=True)
        )

    def solve(self, b: Array) -> Array:
        r"""Solve the linear system.

        Solve the linear system with right hand side :math:`\mb{b}`.

        Args:
            b: Array :math:`\mathbf{b}`.

        Returns:
            Solution to the linear system.
        """
        assert isinstance(self.A.B, CircularConvolve)

        Ahat = self.A.B.h_dft
        Dhat = self.D.h_dft
        bhat = snp.fft.fftn(b, axes=self.fft_axes)
        xhat = (
            bhat - (self.AHEinv * (snp.sum(Ahat * bhat / Dhat, axis=self.sum_axis, keepdims=True)))
        ) / Dhat
        x = snp.fft.ifftn(xhat, axes=self.fft_axes)
        if self.real_result:
            x = x.real

        return x

    def accuracy(self, x: Array, b: Array) -> float:
        r"""Compute solution relative residual.

        Args:
            x: Array :math:`\mathbf{x}` (solution).
            b: Array :math:`\mathbf{b}` (right hand side of linear
               system).

        Returns:
            Relative residual of solution.
        """
        return rel_res(self.A.gram_op(x) + self.D(x), b)