Source code for scico.function

# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Function class."""

from typing import Any, Callable, Optional, Sequence, Tuple, Union

import jax

import scico.numpy as snp
from scico.linop import LinearOperator, jacobian
from scico.numpy import Array, BlockArray
from scico.operator import Operator
from scico.typing import BlockShape, DType, Shape


class Function:
    r"""Function class.

    A :class:`Function` maps multiple :code:`array-like` arguments to
    another :code:`array-like`. It is more general than both
    :class:`.Functional`, which is a mapping to a scalar, and
    :class:`.Operator`, which takes a single argument.
    """

    def __init__(
        self,
        input_shapes: Sequence[Union[Shape, BlockShape]],
        output_shape: Optional[Union[Shape, BlockShape]] = None,
        eval_fn: Optional[Callable] = None,
        input_dtypes: Union[DType, Sequence[DType]] = snp.float32,
        output_dtype: Optional[DType] = None,
        jit: bool = False,
    ):
        """
        Args:
            input_shapes: Shapes of input arrays.
            output_shape: Shape of output array. Defaults to ``None``.
                If ``None``, `output_shape` is determined by evaluating
                `self.__call__` on input arrays of zeros.
            eval_fn: Function used in evaluating this
                :class:`Function`. Defaults to ``None``. Required unless
                `__init__` is being called from a derived class with an
                `_eval` method.
            input_dtypes: `dtype` for input argument. If a single
                `dtype` is specified, it implies a common `dtype` for
                all inputs, otherwise a list or tuple of values should
                be provided, one per input. Defaults to
                :attr:`~numpy.float32`.
            output_dtype: `dtype` for output argument. Defaults to
                ``None``. If ``None``, `output_dtype` is determined by
                evaluating `self.__call__` on input arrays of zeros.
            jit: If ``True``, jit the evaluation function.

        Raises:
            NotImplementedError: If `eval_fn` is ``None`` and the
                instance has no `_eval` attribute.
        """
        self.jit = jit
        self.input_shapes = input_shapes
        if isinstance(input_dtypes, (list, tuple)):
            self.input_dtypes = input_dtypes
        else:
            # A single dtype applies to every input.
            self.input_dtypes = (input_dtypes,) * len(input_shapes)

        if eval_fn is not None:
            self._eval = jax.jit(eval_fn) if jit else eval_fn
        elif not hasattr(self, "_eval"):
            raise NotImplementedError(
                "Function is an abstract base class when the eval_fn parameter is not specified."
            )

        # If the output shape or dtype isn't specified, it can be
        # inferred by calling the evaluation function.
        if output_shape is None or output_dtype is None:
            zeros = [
                snp.zeros(shape, dtype=dtype)
                for (shape, dtype) in zip(self.input_shapes, self.input_dtypes)
            ]
            tmp = self._eval(*zeros)
        if output_shape is None:
            self.output_shape = tmp.shape  # type: ignore
        else:
            self.output_shape = output_shape
        if output_dtype is None:
            self.output_dtype = tmp.dtype
        else:
            self.output_dtype = output_dtype

    def __repr__(self):
        return f"""{type(self)}
input_shapes : {self.input_shapes}
input_dtypes : {self.input_dtypes}
output_shape : {self.output_shape}
output_dtype : {self.output_dtype}
        """

    def __call__(self, *args: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
        """Evaluate this function with the specified parameters.

        Args:
            *args: Parameters at which to evaluate the function.

        Returns:
            Value of function with specified parameters.
        """
        return self._eval(*args)

    def _normalize_index(self, index: int) -> int:
        """Convert a parameter index to its non-negative equivalent.

        The slicing expressions used to split the parameter list into
        free and fixed parts are only correct for non-negative indices,
        so negative indices (counted from the end, as for standard
        sequence indexing) are converted before use.

        Args:
            index: Index of a parameter of this :class:`Function`.

        Returns:
            The equivalent index in the range
            ``0 <= index < len(self.input_shapes)``.

        Raises:
            IndexError: If `index` is out of range for the number of
                inputs of this :class:`Function`.
        """
        num_inputs = len(self.input_shapes)
        if not -num_inputs <= index < num_inputs:
            raise IndexError(
                f"Parameter index {index} is out of range for a Function "
                f"with {num_inputs} inputs."
            )
        return index % num_inputs

    def slice(self, index: int, *fix_args: Union[Array, BlockArray]) -> Operator:
        """Fix all but one parameter, returning a :class:`.Operator`.

        Args:
            index: Index of parameter that remains free. A negative
                value is interpreted relative to the end of the
                parameter list.
            *fix_args: Fixed values for remaining parameters.

        Returns:
            An :class:`.Operator` taking the free parameter of the
            :class:`Function` as its input.

        Raises:
            IndexError: If `index` is out of range.
        """
        index = self._normalize_index(index)

        def pfunc(var_arg):
            # Re-insert the free argument at its original position.
            args = fix_args[0:index] + (var_arg,) + fix_args[index:]
            return self._eval(*args)

        return Operator(
            self.input_shapes[index],
            output_shape=self.output_shape,
            eval_fn=pfunc,
            input_dtype=self.input_dtypes[index],
            output_dtype=self.output_dtype,
            jit=self.jit,
        )

    def join(self) -> Operator:
        """Combine inputs into a :class:`.BlockArray`.

        Construct an equivalent :class:`.Operator` taking a single
        :class:`.BlockArray` input combining all inputs of this
        :class:`Function`.

        Returns:
            An :class:`.Operator` taking a :class:`.BlockArray` as its
            input.

        Raises:
            ValueError: If the input dtypes of this :class:`Function`
                are not all the same.
        """
        for dtype in self.input_dtypes[1:]:
            if dtype != self.input_dtypes[0]:
                raise ValueError(
                    "The join method may only be applied to Functions that have "
                    "homogenous input dtypes."
                )

        def jfunc(blkarr):
            return self._eval(*blkarr.arrays)

        return Operator(
            self.input_shapes,  # type: ignore
            output_shape=self.output_shape,
            eval_fn=jfunc,
            input_dtype=self.input_dtypes[0],
            output_dtype=self.output_dtype,
            jit=self.jit,
        )

    def jvp(
        self, index: int, v: Union[Array, BlockArray], *args: Union[Array, BlockArray]
    ) -> Tuple[Union[Array, BlockArray], Union[Array, BlockArray]]:
        """Jacobian-vector product with respect to a single parameter.

        Compute a Jacobian-vector product with respect to a single
        parameter of a :class:`Function`. Note that the order of the
        parameters specifying where to evaluate the Jacobian and the
        vector in the product is reverse with respect to
        :func:`jax.jvp`.

        Args:
            index: Index of parameter with respect to which the
                Jacobian is to be computed. A negative value is
                interpreted relative to the end of the parameter list.
            v: Vector against which the Jacobian-vector product is to
                be computed.
            *args: Values of function parameters at which Jacobian is
                to be computed.

        Returns:
            A pair consisting of the operator evaluated at the
            parameters specified by `*args` and the Jacobian-vector
            product.

        Raises:
            IndexError: If `index` is out of range.
        """
        index = self._normalize_index(index)
        var_arg = args[index]
        fix_args = args[0:index] + args[(index + 1) :]
        F = self.slice(index, *fix_args)
        return F.jvp(var_arg, v)

    def vjp(
        self, index: int, *args: Union[Array, BlockArray], conjugate: Optional[bool] = True
    ) -> Tuple[Tuple[Any, ...], Callable]:
        """Vector-Jacobian product with respect to a single parameter.

        Compute a vector-Jacobian product with respect to a single
        parameter of a :class:`Function`.

        Args:
            index: Index of parameter with respect to which the
                Jacobian is to be computed. A negative value is
                interpreted relative to the end of the parameter list.
            *args: Values of function parameters at which Jacobian is
                to be computed.
            conjugate: If ``True``, compute the product using the
                conjugate (Hermitian) transpose.

        Returns:
            A pair consisting of the operator evaluated at the
            parameters specified by `*args` and a function that
            computes the vector-Jacobian product.

        Raises:
            IndexError: If `index` is out of range.
        """
        index = self._normalize_index(index)
        var_arg = args[index]
        fix_args = args[0:index] + args[(index + 1) :]
        F = self.slice(index, *fix_args)
        return F.vjp(var_arg, conjugate=conjugate)

    def jacobian(
        self, index: int, *args: Union[Array, BlockArray], include_eval: Optional[bool] = False
    ) -> LinearOperator:
        """Construct Jacobian linear operator for the function.

        Construct a Jacobian :class:`.LinearOperator` that computes
        vector products with the Jacobian with respect to a specified
        variable of the function.

        Args:
            index: Index of parameter with respect to which the
                Jacobian is to be computed. A negative value is
                interpreted relative to the end of the parameter list.
            *args: Values of function parameters at which Jacobian is
                to be computed.
            include_eval: Flag indicating whether the result of
                evaluating the :class:`.Operator` should be included
                (as the first component of a :class:`.BlockArray`) in
                the output of the Jacobian :class:`.LinearOperator`
                constructed by this function.

        Returns:
            A :class:`.LinearOperator` capable of computing
            Jacobian-vector products.

        Raises:
            IndexError: If `index` is out of range.
        """
        index = self._normalize_index(index)
        var_arg = args[index]
        fix_args = args[0:index] + args[(index + 1) :]
        F = self.slice(index, *fix_args)
        return jacobian(F, var_arg, include_eval=include_eval)