Source code for scico.operator._operator

# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Operator base class."""


# Needed to annotate a class method that returns the encapsulating class;
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

from functools import wraps
from typing import Callable, Optional, Tuple, Union

import numpy as np

import jax
import jax.numpy as jnp
from jax.dtypes import result_type

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import is_nested, shape_to_size
from scico.typing import BlockShape, DType, Shape


def _wrap_mul_div_scalar(func: Callable) -> Callable:
    r"""Restrict a binary operator method to scalar second operands.

    Decorator used to define `__mul__`, `__rmul__`, and `__truediv__`
    between an :class:`.Operator` and a scalar. When the decorated
    method is invoked as `binop(Operator, other)`, the call is forwarded
    to `func` only if `other` is scalar-equivalent, as determined by
    `scico.numpy.util.is_scalar_equiv`. Otherwise the operation is
    rejected.

    Args:
        func: Method being wrapped; should be one of `.__mul__()`,
           `.__rmul__()`, or `.__truediv__()`.

    Returns:
        The wrapped method, with the metadata (name, docstring) of
        `func` preserved via :func:`functools.wraps`.

    Notes:
        The returned method raises :exc:`TypeError` when called with an
        `other` operand that is not a scalar.
    """

    @wraps(func)
    def scalar_only(a, b):
        # Guard clause: only scalar second operands are supported.
        if not snp.util.is_scalar_equiv(b):
            raise TypeError(
                f"Operation {func.__name__} not defined between {type(a)} and {type(b)}."
            )
        return func(a, b)

    return scalar_only


class Operator:
    """Generic operator class."""

    def __repr__(self):
        return f"""{type(self)}
shape       : {self.shape}
matrix_shape : {self.matrix_shape}
input_dtype : {self.input_dtype}
output_dtype : {self.output_dtype}
        """

    # See https://numpy.org/doc/stable/user/c-info.beyond-basics.html#ndarray.__array_priority__
    __array_priority__ = 1

    def __init__(
        self,
        input_shape: Union[Shape, BlockShape],
        output_shape: Optional[Union[Shape, BlockShape]] = None,
        eval_fn: Optional[Callable] = None,
        input_dtype: DType = np.float32,
        output_dtype: Optional[DType] = None,
        jit: bool = False,
    ):
        r"""
        Args:
            input_shape: Shape of input array. An integer is interpreted
                as a 1D shape.
            output_shape: Shape of output array. Defaults to ``None``.
                If ``None``, `output_shape` is determined by evaluating
                `self.__call__` on an input array of zeros. An integer is
                interpreted as a 1D shape.
            eval_fn: Function used in evaluating this :class:`.Operator`.
                Defaults to ``None``. Required unless `__init__` is being
                called from a derived class with an `_eval` method.
            input_dtype: `dtype` for input argument. Defaults to
                :attr:`~numpy.float32`. If the :class:`.Operator`
                implements complex-valued operations, this must be a
                complex dtype (typically :attr:`~numpy.complex64`) for
                correct adjoint and gradient calculation.
            output_dtype: `dtype` for output argument. Defaults to
                ``None``. If ``None``, `output_dtype` is determined by
                evaluating `self.__call__` on an input array of zeros.
            jit: If ``True``, call :meth:`Operator.jit()` on this
                :class:`.Operator` to jit the forward, adjoint, and gram
                functions. Same as calling :meth:`Operator.jit` after the
                :class:`.Operator` is created.

        Raises:
            NotImplementedError: If the `eval_fn` parameter is not
               specified and the `_eval` method is not defined in a
               derived class.
        """

        #: Shape of input array or :class:`.BlockArray`.
        self.input_shape: Union[Shape, BlockShape]

        #: Size of flattened input. Sum of product of `input_shape` tuples.
        self.input_size: int

        #: Shape of output array or :class:`.BlockArray`
        self.output_shape: Union[Shape, BlockShape]

        #: Size of flattened output. Sum of product of `output_shape` tuples.
        self.output_size: int

        #: Shape Operator would take if it operated on flattened arrays.
        #: Consists of (output_size, input_size)
        self.matrix_shape: Tuple[int, int]

        #: Shape of Operator, consisting of (output_shape, input_shape).
        self.shape: Tuple[Union[Shape, BlockShape], Union[Shape, BlockShape]]

        #: Dtype of input
        self.input_dtype: DType

        #: Dtype of output
        self.output_dtype: DType

        self.input_shape = (input_shape,) if isinstance(input_shape, int) else input_shape
        self.input_dtype = input_dtype

        # Allows for dynamic creation of new Operator/LinearOperator, e.g. for adjoints
        if eval_fn is not None:
            self._eval = eval_fn  # type: ignore
        elif not hasattr(self, "_eval"):
            raise NotImplementedError(
                "Operator is an abstract base class when the eval_fn parameter is not specified."
            )

        # If the output shape or dtype isn't specified by the user, infer the
        # missing value(s) by invoking the operator on an array of zeros.
        if output_shape is None or output_dtype is None:
            tmp = self(snp.zeros(self.input_shape, dtype=input_dtype))
            if output_shape is None:
                output_shape = tmp.shape  # type: ignore
            if output_dtype is None:
                output_dtype = tmp.dtype
        self.output_shape = (output_shape,) if isinstance(output_shape, int) else output_shape
        self.output_dtype = output_dtype

        # Determine the shape of the "vectorized" operator (as an element of ℝ^{n × m}).
        # If the function returns a BlockArray we need to compute the size of each block,
        # then sum.
        self.input_size = shape_to_size(self.input_shape)
        self.output_size = shape_to_size(self.output_shape)

        self.shape = (self.output_shape, self.input_shape)
        self.matrix_shape = (self.output_size, self.input_size)

        if jit:
            self.jit()

    def jit(self):
        """Activate just-in-time compilation for the `_eval` method."""
        self._eval = jax.jit(self._eval)

    def __call__(
        self, x: Union[Operator, Array, BlockArray]
    ) -> Union[Operator, Array, BlockArray]:
        r"""Evaluate this :class:`Operator` at the point :math:`\mb{x}`.

        Args:
            x: Point at which to evaluate this :class:`.Operator`. If `x`
               is a :class:`jax.Array` or :class:`.BlockArray`, it must
               have `shape == self.input_shape`. If `x` is a
               :class:`.Operator` or :class:`.LinearOperator`, it must
               have `x.output_shape == self.input_shape`.

        Returns:
            :class:`.Operator` evaluated at `x`. If `x` is itself an
            :class:`.Operator`, the result is the composition of this
            :class:`.Operator` with `x`.

        Raises:
            ValueError: If the `input_shape` attribute of the
                :class:`.Operator` is not equal to the input array shape,
                or to the `output_shape` attribute of another
                :class:`.Operator` with which it is composed.
        """
        if isinstance(x, Operator):
            # Compose the two operators if shapes conform. The composite
            # z -> self(x(z)) consumes the input of `x` and produces the
            # output of `self`, so its dtypes are taken accordingly.
            if self.input_shape == x.output_shape:
                return Operator(
                    input_shape=x.input_shape,
                    output_shape=self.output_shape,
                    eval_fn=lambda z: self(x(z)),
                    input_dtype=x.input_dtype,
                    output_dtype=self.output_dtype,
                )
            raise ValueError(f"Incompatible shapes {self.shape}, {x.shape}.")

        if self.input_shape != x.shape:
            raise ValueError(
                f"Cannot evaluate {type(self)} with input_shape={self.input_shape} "
                f"on array with shape={x.shape}."
            )
        return self._eval(x)

    def __add__(self, other: Operator) -> Operator:
        """Return the sum of this :class:`.Operator` and another of the same shape."""
        if isinstance(other, Operator):
            if self.shape == other.shape:
                return Operator(
                    input_shape=self.input_shape,
                    output_shape=self.output_shape,
                    eval_fn=lambda x: self(x) + other(x),
                    input_dtype=self.input_dtype,
                    output_dtype=result_type(self.output_dtype, other.output_dtype),
                )
            raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
        raise TypeError(f"Operation __add__ not defined between {type(self)} and {type(other)}.")

    def __sub__(self, other: Operator) -> Operator:
        """Return the difference of this :class:`.Operator` and another of the same shape."""
        if isinstance(other, Operator):
            if self.shape == other.shape:
                return Operator(
                    input_shape=self.input_shape,
                    output_shape=self.output_shape,
                    eval_fn=lambda x: self(x) - other(x),
                    input_dtype=self.input_dtype,
                    output_dtype=result_type(self.output_dtype, other.output_dtype),
                )
            raise ValueError(f"Shapes {self.shape} and {other.shape} do not match.")
        raise TypeError(f"Operation __sub__ not defined between {type(self)} and {type(other)}.")

    @_wrap_mul_div_scalar
    def __mul__(self, other):
        """Return this :class:`.Operator` scaled by scalar `other` (`op * other`)."""
        return Operator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: other * self(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    def __neg__(self) -> Operator:
        """Return the negation of this :class:`.Operator` (dispatches to `__rmul__`)."""
        return -1.0 * self

    @_wrap_mul_div_scalar
    def __rmul__(self, other):
        """Return this :class:`.Operator` scaled by scalar `other` (`other * op`)."""
        return Operator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: other * self(x),
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    @_wrap_mul_div_scalar
    def __truediv__(self, other):
        """Return this :class:`.Operator` divided by scalar `other`."""
        return Operator(
            input_shape=self.input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(x) / other,
            input_dtype=self.input_dtype,
            output_dtype=result_type(self.output_dtype, other),
        )

    def jvp(self, u, v):
        r"""Compute a Jacobian-vector product.

        Compute the product :math:`J_F(\mb{u}) \mb{v}` where :math:`F`
        represents this operator and :math:`J_F(\mb{u})` is the Jacobian
        of :math:`F` evaluated at :math:`\mb{u}`. This method is
        implemented via a call to :func:`jax.jvp`.

        Args:
            u: Value at which the Jacobian is evaluated.
            v: Vector in the Jacobian-vector product.

        Returns:
            A pair :math:`(F(\mb{u}), J_F(\mb{u}) \mb{v})`, i.e. a pair
            consisting of the operator evaluated at :math:`\mb{u}` and the
            Jacobian-vector product.
        """
        return jax.jvp(self, (u,), (v,))

    def vjp(self, u, conjugate=True):
        r"""Compute a vector-Jacobian product.

        Compute the product :math:`[J_F(\mb{u})]^T \mb{v}` where :math:`F`
        represents this operator and :math:`J_F(\mb{u})` is the Jacobian
        of :math:`F` evaluated at :math:`\mb{u}`. Instead of directly
        computing the vector-Jacobian product, this method returns a
        function, taking :math:`\mb{v}` as an argument, that returns the
        product. This method is implemented via a call to
        :func:`jax.vjp`.

        Args:
            u: Value at which the Jacobian is evaluated.
            conjugate: If ``True``, compute the product using the
                conjugate (Hermitian) transpose.

        Returns:
            A pair :math:`(F(\mb{u}), G(\cdot))` where :math:`G(\cdot)` is
            a function that computes the vector-Jacobian product, i.e.
            :math:`G(\mb{v}) = [J_F(\mb{u})]^T \mb{v}` when `conjugate` is
            ``False``, or :math:`G(\mb{v}) = [J_F(\mb{u})]^H \mb{v}` when
            `conjugate` is ``True``.
        """
        Fu, G = jax.vjp(self, u)

        if conjugate:
            # Conjugating before and after the pullback turns the transpose
            # computed by jax.vjp into the conjugate (Hermitian) transpose.
            def Gmap(v):
                return G(v.conj())[0].conj()

        else:

            def Gmap(v):
                return G(v)[0]

        return Fu, Gmap

    def freeze(self, argnum: int, val: Union[Array, BlockArray]) -> Operator:
        """Return a new :class:`.Operator` with fixed block argument.

        Return a new :class:`.Operator` with block argument `argnum`
        fixed to value `val`.

        Args:
            argnum: Index of block to freeze. Must be non-negative and
                less than the number of blocks in an input array.
            val: Value to fix the `argnum`-th input to.

        Returns:
            A new :class:`.Operator` with one of the blocks of the input
            fixed to the specified value. It has the same `input_dtype`
            and `output_dtype` as this :class:`.Operator`. If only one
            block remains free, the new operator takes a plain array
            rather than a :class:`.BlockArray`.

        Raises:
            ValueError: If the :class:`.Operator` does not take a
                :class:`.BlockArray` as its input, if the block index is
                negative or equals or exceeds the number of blocks, or if
                the shape of the fixed value differs from the shape of
                the specified block.
        """
        if not is_nested(self.input_shape):
            raise ValueError(
                "The freeze method can only be applied to Operators that take BlockArray inputs."
            )

        input_ndim = len(self.input_shape)
        # Negative indices are rejected explicitly: they would otherwise pass
        # the bound check but never match any block index below, silently
        # producing an operator with the wrong input shape.
        if argnum < 0 or argnum >= input_ndim:
            raise ValueError(
                f"Parameter argnum to freeze must be non-negative and less than the number of "
                f"input arguments to this operator ({input_ndim}); got {argnum}."
            )

        if val.shape != self.input_shape[argnum]:
            raise ValueError(
                f"Value to be frozen at position {argnum} must have shape "
                f"{self.input_shape[argnum]}, got {val.shape}."
            )

        # Shapes of the blocks that remain free after freezing block `argnum`.
        free_shapes = tuple(s for i, s in enumerate(self.input_shape) if i != argnum)
        # A single remaining block is taken as a plain array, not a BlockArray.
        single_block = len(free_shapes) == 1
        input_shape: Union[Shape, BlockShape]
        input_shape = free_shapes[0] if single_block else free_shapes  # type: ignore

        def concat_args(args):
            """Rebuild the full block input with `val` inserted at `argnum`.

            E.g. if this operator takes a blockarray with two blocks, then
            concat_args(args) = snp.blockarray([val, args]) if argnum = 0
            concat_args(args) = snp.blockarray([args, val]) if argnum = 1
            """
            if single_block:
                blocks = [args]
            else:
                blocks = [args[i] for i in range(len(free_shapes))]
            blocks.insert(argnum, val)
            return snp.blockarray(blocks)

        return Operator(
            input_shape=input_shape,
            output_shape=self.output_shape,
            eval_fn=lambda x: self(concat_args(x)),
            input_dtype=self.input_dtype,
            output_dtype=self.output_dtype,
        )