# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Linear operator base class."""


# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import operator
from functools import partial, wraps
from typing import Callable, Optional, Union

import numpy as np

import jax
import jax.numpy as jnp
from jax.dtypes import result_type

import scico.numpy as snp
from scico._autograd import linear_adjoint
from scico.numpy import Array, BlockArray
from scico.numpy.util import is_complex_dtype
from scico.operator._operator import Operator, _wrap_mul_div_scalar
from scico.typing import BlockShape, DType, Shape


def _wrap_add_sub(func: Callable, op: Callable) -> Callable:
    r"""Wrapper function for defining `__add__`, `__sub__`.

    Wrapper function for defining `__add__`, `__sub__` between
    :class:`LinearOperator` and other objects.

    Handles shape checking and dispatching based on operand types:

    - If one of the two operands is an :class:`.Operator`, an
      :class:`.Operator` is returned.
    - If both operands are :class:`LinearOperator` of different types,
      a generic :class:`LinearOperator` is returned.
    - If both operands are :class:`LinearOperator` of the same type, a
      special constructor can be called (see, e.g., `Conv2d.__add__`).

    Args:
        func: Should be either `.__add__` or `.__sub__`.
        op: Functional equivalent of `func`, e.g. :func:`operator.add`
            for `func == __add__`.

    Raises:
        ValueError: If the shapes of the two operands do not match.
        TypeError: If one of the two operands is not an
            :class:`.Operator` or :class:`LinearOperator`.
    """

    @wraps(func)
    def wrapper(
        a: LinearOperator, b: Union[Operator, LinearOperator]
    ) -> Union[Operator, LinearOperator]:
        if isinstance(b, Operator):
            if a.shape == b.shape:
                if isinstance(b, type(a)):
                    # same type of linop, eg convolution can have special
                    # behavior (see Conv2d.__add__)
                    return func(a, b)
                if isinstance(
                    b, LinearOperator
                ):  # LinearOperator + LinearOperator -> LinearOperator
                    return LinearOperator(
                        input_shape=a.input_shape,
                        output_shape=a.output_shape,
                        eval_fn=lambda x: op(a(x), b(x)),
                        adj_fn=lambda x: op(a.adj(x), b.adj(x)),
                        input_dtype=a.input_dtype,
                        output_dtype=result_type(a.output_dtype, b.output_dtype),
                    )
                # LinearOperator + Operator -> Operator
                return Operator(
                    input_shape=a.input_shape,
                    output_shape=a.output_shape,
                    eval_fn=lambda x: op(a(x), b(x)),
                    input_dtype=a.input_dtype,
                    output_dtype=result_type(a.output_dtype, b.output_dtype),
                )
            raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.")
        raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

    return wrapper
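
# Usage sketch (illustrative, not part of the module API): with the wrapper
# above, adding two generic LinearOperator instances of matching shape takes
# the LinearOperator + LinearOperator branch and returns a new generic
# LinearOperator, e.g.
#
#     A = LinearOperator(input_shape=(4,), eval_fn=lambda x: 2.0 * x)
#     B = LinearOperator(input_shape=(4,), eval_fn=lambda x: 3.0 * x)
#     C = A + B  # C(x) == A(x) + B(x); C.adj(y) == A.adj(y) + B.adj(y)
#
# while mismatched shapes raise ValueError and adding a non-Operator raises
# TypeError.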


class LinearOperator(Operator):
    """Generic linear operator base class"""

    def __init__(
        self,
        input_shape: Union[Shape, BlockShape],
        output_shape: Optional[Union[Shape, BlockShape]] = None,
        eval_fn: Optional[Callable] = None,
        adj_fn: Optional[Callable] = None,
        input_dtype: DType = np.float32,
        output_dtype: Optional[DType] = None,
        jit: bool = False,
    ):
        r"""
        Args:
            input_shape: Shape of input array.
            output_shape: Shape of output array. Defaults to ``None``.
                If ``None``, `output_shape` is determined by evaluating
                `self.__call__` on an input array of zeros.
            eval_fn: Function used in evaluating this
                :class:`LinearOperator`. Defaults to ``None``. If
                ``None``, then `self.__call__` must be defined in any
                derived classes.
            adj_fn: Function used to evaluate the adjoint of this
                :class:`LinearOperator`. Defaults to ``None``. If
                ``None``, the adjoint is not set, and
                :meth:`._set_adjoint` will be called silently at the
                first call to :meth:`.adj`, or it can be called
                manually.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`. If the :class:`.LinearOperator`
                implements complex-valued operations, this must be a
                complex dtype (typically :attr:`~numpy.complex64`) for
                correct adjoint and gradient calculation.
            output_dtype: `dtype` for output argument. Defaults to
                ``None``. If ``None``, `output_dtype` is determined by
                evaluating `self.__call__` on an input array of zeros.
            jit: If ``True``, call :meth:`.jit()` on this
                :class:`LinearOperator` to jit the forward, adjoint, and
                gram functions. Same as calling :meth:`.jit` after the
                :class:`LinearOperator` is created.
        """

        super().__init__(
            input_shape=input_shape,
            output_shape=output_shape,
            eval_fn=eval_fn,
            input_dtype=input_dtype,
            output_dtype=output_dtype,
            jit=False,
        )

        if not hasattr(self, "_adj"):
            self._adj: Optional[Callable] = None
        if not hasattr(self, "_gram"):
            self._gram: Optional[Callable] = None
        if callable(adj_fn):
            self._adj = adj_fn
            self._gram = lambda x: self.adj(self(x))
        elif adj_fn is not None:
            raise TypeError(f"Parameter adj_fn must be either a Callable or None; got {adj_fn}.")

        if jit:
            self.jit()

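    # Construction sketch (illustrative): an operator implementing a
    # complex-valued map must declare a complex input_dtype for correct
    # automatic adjoint derivation, e.g.
    #
    #     A = LinearOperator(
    #         input_shape=(3,),
    #         eval_fn=lambda x: (1 + 1j) * x,
    #         input_dtype=snp.complex64,
    #     )
    #     A.adj(A(snp.ones((3,), dtype=snp.complex64)))  # == 2 * ones
    #
    # since the derived adjoint multiplies by conj(1 + 1j) and
    # (1 + 1j) * (1 - 1j) == 2.
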
    def _set_adjoint(self):
        """Automatically create adjoint method."""
        adj_fun = linear_adjoint(self.__call__, snp.zeros(self.input_shape, dtype=self.input_dtype))
        self._adj = lambda x: adj_fun(x)[0]

    def _set_gram(self):
        """Automatically create gram method."""
        self._gram = lambda x: self.adj(self(x))

    def jit(self):
        """Replace the private functions :meth:`._eval`, :meth:`._adj`,
        :meth:`._gram` with jitted versions.
        """
        if self._adj is None:
            self._set_adjoint()
        if self._gram is None:
            self._set_gram()

        self._eval = jax.jit(self._eval)
        self._adj = jax.jit(self._adj)
        self._gram = jax.jit(self._gram)
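
    # Usage sketch (illustrative): jitting is worthwhile when the operator
    # is applied repeatedly, e.g. inside an iterative solver:
    #
    #     A = LinearOperator(input_shape=(16,), eval_fn=lambda x: 2.0 * x)
    #     A.jit()  # subsequent A(x), A.adj(y), A.gram(x) use compiled code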

    @partial(_wrap_add_sub, op=operator.add)
    def __add__(self, other):
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(x) + other(x),
            adj_fn=lambda x: self.adj(x) + other.adj(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other.output_dtype),
        )

    @partial(_wrap_add_sub, op=operator.sub)
    def __sub__(self, other):
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(x) - other(x),
            adj_fn=lambda x: self.adj(x) - other.adj(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other.output_dtype),
        )

    @_wrap_mul_div_scalar
    def __mul__(self, other):
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: other * self(x),
            adj_fn=lambda x: snp.conj(other) * self.adj(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    @_wrap_mul_div_scalar
    def __rmul__(self, other):
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: other * self(x),
            adj_fn=lambda x: snp.conj(other) * self.adj(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    @_wrap_mul_div_scalar
    def __truediv__(self, other):
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(x) / other,
            adj_fn=lambda x: self.adj(x) / snp.conj(other),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    def __matmul__(self, other):
        # self @ other
        return self(other)

    def __rmatmul__(self, other):
        # other @ self
        if isinstance(other, LinearOperator):
            return other(self)

        if isinstance(other, (np.ndarray, jnp.ndarray)):
            # for real valued inputs: y @ self == (self.T @ y.T).T
            # for complex: y @ self == (self.conj().T @ y.conj().T).conj().T
            # self.conj().T == self.adj
            return self.adj(other.conj().T).conj().T

        raise NotImplementedError(
            f"Operation __rmatmul__ not defined between {type(self)} and {type(other)}."
        )
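
    # Arithmetic sketch (illustrative): scalar multiplication conjugates the
    # scalar in the adjoint, and right matrix multiplication by an array
    # applies the adjoint, e.g. for a complex scalar c and conforming y:
    #
    #     (c * A).adj(y)  # == snp.conj(c) * A.adj(y)
    #     y @ A           # == A.adj(y.conj().T).conj().T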

    def __call__(
        self, x: Union[LinearOperator, Array, BlockArray]
    ) -> Union[LinearOperator, Array, BlockArray]:
        r"""Evaluate this :class:`LinearOperator` at the point :math:`\mb{x}`.

        Args:
            x: Point at which to evaluate this :class:`LinearOperator`.
               If `x` is a :class:`jax.Array` or :class:`.BlockArray`,
               must have `shape == self.input_shape`. If `x` is a
               :class:`LinearOperator`, must have
               `x.output_shape == self.input_shape`.
        """
        if isinstance(x, LinearOperator):
            return ComposedLinearOperator(self, x)
        # Use Operator __call__ for LinearOperator @ array or LinearOperator @ Operator
        return super().__call__(x)
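
    # Call sketch (illustrative): applying a LinearOperator to another
    # LinearOperator composes them instead of evaluating, so for conforming
    # operators A and B:
    #
    #     A(B)    # ComposedLinearOperator with A(B)(x) == A(B(x))
    #     A @ B   # equivalent, via __matmul__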

    def adj(
        self, y: Union[LinearOperator, Array, BlockArray]
    ) -> Union[LinearOperator, Array, BlockArray]:
        """Adjoint of this :class:`LinearOperator`.

        Compute the adjoint of this :class:`LinearOperator` applied to
        input `y`.

        Args:
            y: Point at which to compute adjoint. If `y` is
                :class:`jax.Array` or :class:`.BlockArray`, must have
                `shape == self.output_shape`. If `y` is a
                :class:`LinearOperator`, must have
                `y.output_shape == self.output_shape`.

        Returns:
            Adjoint evaluated at `y`.
        """
        if self._adj is None:
            self._set_adjoint()

        if isinstance(y, LinearOperator):
            return ComposedLinearOperator(self.H, y)
        if self.output_dtype != y.dtype:
            raise ValueError(f"Dtype error: expected {self.output_dtype}, got {y.dtype}.")
        if self.output_shape != y.shape:
            raise ValueError(
                f"""Shapes do not conform: input array with shape {y.shape} does not match
                LinearOperator output_shape {self.output_shape}."""
            )
        assert self._adj is not None
        return self._adj(y)
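
    # Verification sketch (illustrative): for a real-valued operator A, the
    # adjoint satisfies the inner product identity <A(x), y> == <x, A.adj(y)>:
    #
    #     x = snp.ones(A.input_shape, dtype=A.input_dtype)
    #     y = snp.ones(A.output_shape, dtype=A.output_dtype)
    #     snp.sum(A(x) * y) == snp.sum(x * A.adj(y))  # up to float tolerance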

    @property
    def T(self) -> LinearOperator:
        """Transpose of this :class:`LinearOperator`.

        Return a new :class:`LinearOperator` that implements the
        transpose of this :class:`LinearOperator`. For a real-valued
        :class:`LinearOperator` `A` (`A.input_dtype` is
        :attr:`~numpy.float32` or :attr:`~numpy.float64`), the
        :class:`LinearOperator` `A.T` implements the adjoint:
        `A.T(y) == A.adj(y)`. For a complex-valued
        :class:`LinearOperator` `A` (`A.input_dtype` is
        :attr:`~numpy.complex64` or :attr:`~numpy.complex128`), the
        :class:`LinearOperator` `A.T` is not the adjoint. For the
        conjugate transpose, use `.conj().T` or :meth:`.H`.
        """
        if is_complex_dtype(self.input_dtype):
            return LinearOperator(
                input_shape=self.output_shape,
                output_shape=self.input_shape,
                eval_fn=lambda x: self.adj(x.conj()).conj(),
                adj_fn=self.__call__,
                input_dtype=self.input_dtype,
                output_dtype=self.output_dtype,
            )
        return LinearOperator(
            input_shape=self.output_shape,
            output_shape=self.input_shape,
            eval_fn=self.adj,
            adj_fn=self.__call__,
            input_dtype=self.output_dtype,
            output_dtype=self.input_dtype,
        )

    @property
    def H(self) -> LinearOperator:
        """Hermitian transpose of this :class:`LinearOperator`.

        Return a new :class:`LinearOperator` that is the Hermitian
        transpose of this :class:`LinearOperator`. For a real-valued
        :class:`LinearOperator` `A` (`A.input_dtype` is
        :attr:`~numpy.float32` or :attr:`~numpy.float64`), the
        :class:`LinearOperator` `A.H` is equivalent to `A.T`. For a
        complex-valued :class:`LinearOperator` `A` (`A.input_dtype` is
        :attr:`~numpy.complex64` or :attr:`~numpy.complex128`), the
        :class:`LinearOperator` `A.H` implements the adjoint of `A`:
        `A.H @ y == A.adj(y) == A.conj().T @ y`.

        For the non-conjugate transpose, see :meth:`.T`.
        """
        return LinearOperator(
            input_shape=self.output_shape,
            output_shape=self.input_shape,
            eval_fn=self.adj,
            adj_fn=self.__call__,
            input_dtype=self.output_dtype,
            output_dtype=self.input_dtype,
        )
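
    # Transpose sketch (illustrative): for real A, the properties A.T and A.H
    # both apply A.adj; for complex A they differ:
    #
    #     A.H(y)  # == A.adj(y)                (conjugate transpose)
    #     A.T(y)  # == A.adj(y.conj()).conj()  (non-conjugate transpose)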

    def conj(self) -> LinearOperator:
        """Complex conjugate of this :class:`LinearOperator`.

        Return a new :class:`LinearOperator` `Ac` such that
        `Ac(x) = conj(A)(x)`.
        """
        # A.conj() x == (A @ x.conj()).conj()
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(x.conj()).conj(),
            adj_fn=lambda x: self.adj(x.conj()).conj(),
            input_dtype=self.input_dtype,
            output_dtype=self.output_dtype,
        )

    @property
    def gram_op(self) -> LinearOperator:
        """Gram operator of this :class:`LinearOperator`.

        Return a new :class:`LinearOperator` `G` such that
        `G(x) = A.adj(A(x))`.
        """
        if self._gram is None:
            self._set_gram()
        return LinearOperator(
            input_shape=self.input_shape,
            output_shape=self.input_shape,
            eval_fn=self.gram,
            adj_fn=self.gram,
            input_dtype=self.input_dtype,
            output_dtype=self.output_dtype,
        )

    def gram(
        self, x: Union[LinearOperator, Array, BlockArray]
    ) -> Union[LinearOperator, Array, BlockArray]:
        """Compute `A.adj(A(x))`.

        Args:
            x: Point at which to evaluate the gram operator. If `x` is
               a :class:`jax.Array` or :class:`.BlockArray`, must have
               `shape == self.input_shape`. If `x` is a
               :class:`LinearOperator`, must have
               `x.output_shape == self.input_shape`.

        Returns:
            Result of `A.adj(A(x))`.
        """
        if self._gram is None:
            self._set_gram()
        assert self._gram is not None
        return self._gram(x)
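
    # Gram sketch (illustrative): gram and gram_op both apply A.adj after A:
    #
    #     A.gram(x)     # == A.adj(A(x))
    #     A.gram_op(x)  # same result, with gram_op usable as an operator
    #                   # in further compositions, e.g. A.gram_op @ x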

class ComposedLinearOperator(LinearOperator):
    """A composition of two :class:`LinearOperator` objects.

    A new :class:`LinearOperator` formed by the composition of two other
    :class:`LinearOperator` objects.
    """

    def __init__(self, A: LinearOperator, B: LinearOperator, jit: bool = False):
        r"""
        A :class:`ComposedLinearOperator` `AB` implements
        `AB @ x == A @ B @ x`. :class:`LinearOperator` `A` and `B` are
        stored as attributes of the :class:`ComposedLinearOperator`.

        :class:`LinearOperator` `A` and `B` must have compatible shapes
        and dtypes: `A.input_shape == B.output_shape` and
        `A.input_dtype == B.output_dtype`.

        Args:
            A: First (left) :class:`LinearOperator`.
            B: Second (right) :class:`LinearOperator`.
            jit: If ``True``, call :meth:`~.LinearOperator.jit()` on this
                :class:`LinearOperator` to jit the forward, adjoint, and
                gram functions. Same as calling
                :meth:`~.LinearOperator.jit` after the
                :class:`LinearOperator` is created.
        """
        if not isinstance(A, LinearOperator):
            raise TypeError(
                "The first argument to ComposedLinearOperator must be a LinearOperator; "
                f"got {type(A)}."
            )
        if not isinstance(B, LinearOperator):
            raise TypeError(
                "The second argument to ComposedLinearOperator must be a LinearOperator; "
                f"got {type(B)}."
            )
        if A.input_shape != B.output_shape:
            raise ValueError(f"Incompatible LinearOperator shapes {A.shape}, {B.shape}.")
        if A.input_dtype != B.output_dtype:
            raise ValueError(
                f"Incompatible LinearOperator dtypes {A.input_dtype}, {B.output_dtype}."
            )

        self.A = A
        self.B = B

        super().__init__(
            input_shape=self.B.input_shape,
            output_shape=self.A.output_shape,
            input_dtype=self.B.input_dtype,
            output_dtype=self.A.output_dtype,
            eval_fn=lambda x: self.A(self.B(x)),
            adj_fn=lambda z: self.B.adj(self.A.adj(z)),
            jit=jit,
        )
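
# Composition sketch (illustrative): composing conforming operators chains the
# forward maps and reverses the order of the adjoints, e.g.
#
#     A = LinearOperator(input_shape=(3,), eval_fn=lambda x: 2.0 * x)
#     B = LinearOperator(input_shape=(3,), eval_fn=lambda x: 3.0 * x)
#     AB = ComposedLinearOperator(A, B)  # also created by A(B) or A @ B
#     AB(x)      # == A(B(x)) == 6.0 * x
#     AB.adj(y)  # == B.adj(A.adj(y)) == 6.0 * y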