scico.operator

Operator functions and classes.

Functions

operator_from_function(f, classname[, f_name])

Make an Operator from a function.

Classes

Abs(input_shape, *args[, input_dtype, ...])

Operator version of scico.numpy.abs.

Angle(input_shape, *args[, input_dtype, ...])

Operator version of scico.numpy.angle.

BiConvolve(input_shape[, input_dtype, mode, jit])

Biconvolution operator.

DiagonalStack(ops[, collapse_input, ...])

A diagonal stack of operators.

Exp(input_shape, *args[, input_dtype, ...])

Operator version of scico.numpy.exp.

Operator(input_shape[, output_shape, ...])

Generic operator class.

VerticalStack(ops[, collapse_output, jit])

A vertical stack of operators.

class scico.operator.Operator(input_shape, output_shape=None, eval_fn=None, input_dtype=<class 'numpy.float32'>, output_dtype=None, jit=False)

Bases: object

Generic operator class.

Parameters:
  • input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.

  • output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.

  • eval_fn (Optional[Callable]) – Function used in evaluating this Operator. Defaults to None. Required unless __init__ is being called from a derived class with an _eval method.

  • input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.

  • output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.

  • jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.

Raises:

NotImplementedError – If the eval_fn parameter is not specified and the _eval method is not defined in a derived class.
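When output_shape or output_dtype is None, the constructor determines it by evaluating the Operator on an array of zeros. The sketch below mimics that rule with plain JAX. infer_output_spec is a hypothetical helper, not part of scico, and the jax.eval_shape line is shown only as an alternative that reaches the same answer from shapes and dtypes alone.

```python
import jax
import jax.numpy as jnp


def infer_output_spec(eval_fn, input_shape, input_dtype=jnp.float32):
    """Hypothetical helper (not part of scico): mimic how output_shape and
    output_dtype are inferred when they are None, by evaluating on zeros."""
    y = eval_fn(jnp.zeros(input_shape, dtype=input_dtype))
    return y.shape, y.dtype


# A reduction over the first axis maps shape (3, 4) to shape (4,).
shape, dtype = infer_output_spec(lambda x: jnp.sum(x, axis=0), (3, 4))

# Alternative: jax.eval_shape traces the function on abstract values, so no
# zero array is allocated and no arithmetic is performed.
spec = jax.eval_shape(
    lambda x: jnp.sum(x, axis=0), jax.ShapeDtypeStruct((3, 4), jnp.float32)
)
```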

__call__(x)

Evaluate this Operator at the point \(\mb{x}\).

Parameters:

x (Union[Operator, Array, BlockArray]) – Point at which to evaluate this Operator. If x is a jax.Array or BlockArray, it must have shape == self.input_shape. If x is an Operator or LinearOperator, it must have x.output_shape == self.input_shape.

Return type:

Union[Operator, Array, BlockArray]

Returns:

Operator evaluated at x.

Raises:

ValueError – If the input_shape attribute of the Operator is not equal to the input array shape, or to the output_shape attribute of another Operator with which it is composed.
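The two behaviors of __call__ (evaluation on an array, composition with another Operator) and the shape check that guards both can be sketched with a small stand-in class. MiniOperator is hypothetical and far simpler than the scico class; it uses NumPy arrays only and ignores dtypes and BlockArray inputs.

```python
import numpy as np


class MiniOperator:
    """Hypothetical, simplified stand-in for Operator.__call__ semantics
    (not the scico class): shape-checked evaluation and composition."""

    def __init__(self, input_shape, output_shape, eval_fn):
        self.input_shape = tuple(input_shape)
        self.output_shape = tuple(output_shape)
        self._eval = eval_fn

    def __call__(self, x):
        if isinstance(x, MiniOperator):
            # Composition self(x(.)) requires x.output_shape == self.input_shape.
            if x.output_shape != self.input_shape:
                raise ValueError(f"cannot compose: {x.output_shape} != {self.input_shape}")
            inner, outer = x._eval, self._eval
            return MiniOperator(x.input_shape, self.output_shape, lambda z: outer(inner(z)))
        x = np.asarray(x)
        # Evaluation requires the array shape to match input_shape.
        if x.shape != self.input_shape:
            raise ValueError(f"input shape {x.shape} != {self.input_shape}")
        return self._eval(x)


A = MiniOperator((3,), (3,), lambda x: 2.0 * x)
B = MiniOperator((3,), (1,), lambda x: np.sum(x, keepdims=True))
C = B(A)  # composed operator: x -> B(A(x))
```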

freeze(argnum, val)

Return a new Operator with fixed block argument.

Return a new Operator with block argument argnum fixed to value val.

Parameters:
  • argnum (int) – Index of block to freeze. Must be less than the number of blocks in the input array.

  • val (Union[Array, BlockArray]) – Value to fix the argnum-th input to.

Return type:

Operator

Returns:

A new Operator with one of the blocks of the input fixed to the specified value.

Raises:

ValueError – If the Operator does not take a BlockArray as its input, if the block index equals or exceeds the number of blocks, or if the shape of the fixed value differs from the shape of the specified block.
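The idea behind freeze (fix one block of a multi-block input and obtain a function of the remaining blocks) can be sketched in NumPy with a tuple of arrays standing in for a BlockArray. freeze_block is a hypothetical helper, not scico's implementation; its validation mirrors the error conditions listed above, and requiring at least two blocks is an assumption of this sketch.

```python
import numpy as np


def freeze_block(fn, block_shapes, argnum, val):
    """Hypothetical helper (not scico's implementation): fix block `argnum`
    of a function of a tuple of arrays, returning a function of the others."""
    # Sketch assumption: a block input here means a tuple of at least two blocks.
    if len(block_shapes) < 2:
        raise ValueError("freezing requires an input with at least two blocks")
    if not 0 <= argnum < len(block_shapes):
        raise ValueError(f"argnum must be in [0, {len(block_shapes) - 1}]")
    val = np.asarray(val)
    if val.shape != tuple(block_shapes[argnum]):
        raise ValueError(f"val has shape {val.shape}, expected {block_shapes[argnum]}")

    def frozen(*remaining):
        blocks = list(remaining)
        blocks.insert(argnum, val)  # put the fixed value back in its slot
        return fn(tuple(blocks))

    return frozen


f = lambda b: b[0] * b[1] + b[2]
shapes = ((2,), (2,), (2,))
g = freeze_block(f, shapes, 1, np.array([10.0, 20.0]))  # fix the middle block
```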

jit()

Activate just-in-time compilation for the _eval method.

jvp(u, v)

Compute a Jacobian-vector product.

Compute the product \(J_F(\mb{u}) \mb{v}\) where \(F\) represents this operator and \(J_F(\mb{u})\) is the Jacobian of \(F\) evaluated at \(\mb{u}\). This method is implemented via a call to jax.jvp.

Parameters:
  • u – Value at which the Jacobian is evaluated.

  • v – Vector in the Jacobian-vector product.

Returns:

A pair \((F(\mb{u}), J_F(\mb{u}) \mb{v})\), i.e. a pair consisting of the operator evaluated at \(\mb{u}\) and the Jacobian-vector product.
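Since jvp is described as a call to jax.jvp, the returned pair can be illustrated with jax.jvp directly. The sketch uses jnp.sin as a stand-in for an Operator; it is not a scico Operator, but for an elementwise map the Jacobian is diagonal, which makes the expected product easy to check.

```python
import jax
import jax.numpy as jnp

F = jnp.sin  # elementwise nonlinear map standing in for an Operator
u = jnp.array([0.0, 1.0, 2.0])
v = jnp.array([1.0, 1.0, 0.5])

# jax.jvp takes tuples of primals and tangents and returns (F(u), J_F(u) v).
Fu, Jv = jax.jvp(F, (u,), (v,))

# Here J_F(u) = diag(cos(u)), so J_F(u) v equals cos(u) * v elementwise.
```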

vjp(u, conjugate=True)

Compute a vector-Jacobian product.

Compute the product \([J_F(\mb{u})]^T \mb{v}\), or \([J_F(\mb{u})]^H \mb{v}\) when conjugate is True (the default), where \(F\) represents this operator and \(J_F(\mb{u})\) is the Jacobian of \(F\) evaluated at \(\mb{u}\). Rather than computing the vector-Jacobian product directly, this method returns a function that takes \(\mb{v}\) as its argument and returns the product. This method is implemented via a call to jax.vjp.

Parameters:
  • u – Value at which the Jacobian is evaluated.

  • conjugate – If True, compute the product using the conjugate (Hermitian) transpose.

Returns:

A pair \((F(\mb{u}), G(\cdot))\) where \(G(\cdot)\) is a function that computes the vector-Jacobian product, i.e. \(G(\mb{v}) = [J_F(\mb{u})]^T \mb{v}\) when conjugate is False, or \(G(\mb{v}) = [J_F(\mb{u})]^H \mb{v}\) when conjugate is True.
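The difference between the two conjugate settings matters only for complex-valued operators. The sketch below uses jax.vjp on a complex linear map F(x) = A x, whose Jacobian is A, and obtains the Hermitian product from JAX's transpose-only convention via conj(G(conj(v))). This conjugation identity is one way to get the conjugate=True result; whether scico computes it exactly this way is an assumption.

```python
import jax
import jax.numpy as jnp

# Complex linear map F(x) = A x; its Jacobian is A at every point.
A = jnp.array([[1 + 2j, 3 - 1j], [0.5j, 2 + 0j]], dtype=jnp.complex64)
F = lambda x: A @ x

u = jnp.array([1 + 1j, 2 - 1j], dtype=jnp.complex64)
v = jnp.array([1 - 1j, 0.5 + 2j], dtype=jnp.complex64)

Fu, G = jax.vjp(F, u)  # G returns a tuple with one entry per primal argument

# conjugate=False analogue: plain transpose, [J_F(u)]^T v.
vjp_T = G(v)[0]
# conjugate=True analogue: [J_F(u)]^H v = conj([J_F(u)]^T conj(v)).
vjp_H = jnp.conj(G(jnp.conj(v))[0])
```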

class scico.operator.BiConvolve(input_shape, input_dtype=<class 'numpy.float32'>, mode='full', jit=True)

Bases: Operator

Biconvolution operator.

A BiConvolve operator accepts a BlockArray input with two blocks of equal ndims, and convolves the first block with the second.

If A is a BiConvolve operator, then A(snp.blockarray([x, h])) equals jax.scipy.signal.convolve(x, h).

Parameters:
  • input_shape (Tuple[Tuple[int, ...], Tuple[int, ...]]) – Shape of input BlockArray. Must correspond to a BlockArray with two blocks of equal ndims.

  • input_dtype (DType) – dtype for input argument. Defaults to float32.

  • mode (str) – A string indicating the size of the output. One of “full”, “valid”, “same”. Defaults to “full”.

  • jit (bool) – If True, jit the evaluation of this Operator.

For more details on mode, see jax.scipy.signal.convolve.
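Since A(snp.blockarray([x, h])) equals jax.scipy.signal.convolve(x, h), the map that BiConvolve computes can be sketched directly with that function. Here a Python pair stands in for the two-block BlockArray, and biconvolve is a hypothetical name. The last lines check that the map is bilinear, which is why freezing either block yields a linear operator.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve


def biconvolve(blocks, mode="full"):
    """Hypothetical sketch of the map computed by BiConvolve; `blocks` is a
    pair (x, h) standing in for the two-block BlockArray input."""
    x, h = blocks
    return convolve(x, h, mode=mode)


x = jnp.array([1.0, 2.0, 3.0])
h = jnp.array([1.0, 1.0])
full = biconvolve((x, h))                 # length 3 + 2 - 1 = 4
valid = biconvolve((x, h), mode="valid")  # only fully overlapping positions: length 2

# Bilinearity: for fixed h the map is linear in x (and vice versa).
x2 = jnp.array([0.5, -1.0, 4.0])
lhs = biconvolve((2.0 * x + x2, h))
rhs = 2.0 * biconvolve((x, h)) + biconvolve((x2, h))
```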

freeze(argnum, val)

Freeze the argnum parameter.

Return a new LinearOperator with block argument argnum fixed to value val.

If argnum == 0, a ConvolveByX object is returned. If argnum == 1, a Convolve object is returned.

Parameters:
  • argnum (int) – Index of block to freeze. Must be 0 or 1.

  • val (Array) – Value to fix the argnum-th input to.

Return type:

LinearOperator
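Freezing one block of a convolution leaves a linear map of the other block, which is why this freeze returns a LinearOperator. The JAX sketch below builds the two frozen maps as plain functions (they play the roles of ConvolveByX and Convolve but are not those classes) and uses jax.jacfwd to recover the constant matrix of one of them.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve

x = jnp.array([1.0, 2.0, 3.0])
h = jnp.array([1.0, -1.0])

# argnum == 0 frozen: x is fixed, leaving a linear map of h (role of ConvolveByX).
conv_by_x = lambda h_: convolve(x, h_, mode="full")
# argnum == 1 frozen: h is fixed, leaving a linear map of x (role of Convolve).
conv_with_h = lambda x_: convolve(x_, h, mode="full")

# Because conv_with_h is linear, its Jacobian is a constant matrix that
# reproduces it: shape (len(x) + len(h) - 1, len(x)) = (4, 3).
M = jax.jacfwd(conv_with_h)(x)
```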

scico.operator.operator_from_function(f, classname, f_name=None)

Make an Operator from a function.

Example

>>> import scico.numpy as snp
>>> from scico.operator import operator_from_function
>>> AbsVal = operator_from_function(snp.abs, 'AbsVal')
>>> H = AbsVal((2,))
>>> H(snp.array([1.0, -1.0]))
Array([1., 1.], dtype=float32)

Parameters:
  • f (Callable) – Function from which to create an Operator.

  • classname (str) – Name of the resulting class.

  • f_name (Optional[str]) – Name of f for use in docstrings. Useful for getting the correct version of wrapped functions. Defaults to f"{f.__module__}.{f.__name__}".
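A function-to-class factory of this kind can be written with Python's built-in type(). The sketch below is a hypothetical illustration of the idea (make_operator_class is not scico's implementation): it checks the input shape, forwards extra positional and keyword arguments to f, and builds the docstring from f_name using the default rule stated above.

```python
import numpy as np


def make_operator_class(f, classname, f_name=None):
    """Hypothetical class factory (not scico's implementation) illustrating
    the idea behind operator_from_function."""
    if f_name is None:
        f_name = f"{f.__module__}.{f.__name__}"

    def __init__(self, input_shape, *args, **kwargs):
        self.input_shape = tuple(input_shape)
        self._args, self._kwargs = args, kwargs  # forwarded to f on each call

    def __call__(self, x):
        x = np.asarray(x)
        if x.shape != self.input_shape:
            raise ValueError(f"input shape {x.shape} != {self.input_shape}")
        return f(x, *self._args, **self._kwargs)

    doc = f"Operator version of {f_name}."
    return type(classname, (object,), {"__init__": __init__, "__call__": __call__, "__doc__": doc})


AbsVal = make_operator_class(np.abs, "AbsVal", f_name="numpy.abs")
Clip = make_operator_class(np.clip, "Clip")  # f_name built from __module__ and __name__
```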

class scico.operator.Abs(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)

Bases: Operator

Operator version of scico.numpy.abs.

Parameters:
  • input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.

  • args (Any) – Positional arguments passed to scico.numpy.abs.

  • input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.

  • output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.

  • output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.

  • jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.

  • kwargs (Any) – Keyword arguments passed to scico.numpy.abs.
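Because abs maps complex values to real ones, an Abs operator meant for complex inputs needs a complex input_dtype, and the inferred output dtype is then real. The sketch shows this dtype behavior with jax.numpy.abs, assuming scico.numpy.abs behaves the same way on plain arrays; it does not construct a scico Abs.

```python
import jax.numpy as jnp

z = jnp.array([3 + 4j, -1j], dtype=jnp.complex64)
mag = jnp.abs(z)  # elementwise modulus, returned as a real array

# Evaluating on complex zeros (as in output_dtype inference) gives a real dtype.
out_dtype = jnp.abs(jnp.zeros((2,), dtype=jnp.complex64)).dtype
```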

class scico.operator.Angle(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)

Bases: Operator

Operator version of scico.numpy.angle.

Parameters:
  • input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.

  • args (Any) – Positional arguments passed to scico.numpy.angle.

  • input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.

  • output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.

  • output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.

  • jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.

  • kwargs (Any) – Keyword arguments passed to scico.numpy.angle.
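The kwargs parameter forwards options to the wrapped function. Assuming scico.numpy.angle accepts the same deg keyword as jax.numpy.angle, passing deg=True would produce degrees instead of radians. The sketch exercises jax.numpy.angle only. It also shows that on real inputs (the float32 default) every phase is 0 or pi, so Angle is informative only with a complex input_dtype.

```python
import jax.numpy as jnp

z = jnp.array([1j, -1 + 0j, 1 + 1j], dtype=jnp.complex64)
rad = jnp.angle(z)            # radians: pi/2, pi, pi/4
deg = jnp.angle(z, deg=True)  # the deg keyword switches the unit to degrees

# With real inputs the phase only encodes the sign: 0 for x >= 0, pi for x < 0.
real_phase = jnp.angle(jnp.array([2.0, -3.0]))
```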

class scico.operator.Exp(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)

Bases: Operator

Operator version of scico.numpy.exp.

Parameters:
  • input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.

  • args (Any) – Positional arguments passed to scico.numpy.exp.

  • input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.

  • output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.

  • output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.

  • jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.

  • kwargs (Any) – Keyword arguments passed to scico.numpy.exp.

class scico.operator.DiagonalStack(ops, collapse_input=True, collapse_output=True, jit=True, **kwargs)

Bases: Operator

A diagonal stack of operators.

Given operators \(A_1, A_2, \dots, A_N\), create the operator \(H\) such that

\[\begin{split}H \left( \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \\ \end{pmatrix} \right) = \begin{pmatrix} A_1(\mb{x}_1) \\ A_2(\mb{x}_2) \\ \vdots \\ A_N(\mb{x}_N) \\ \end{pmatrix} \;.\end{split}\]

By default (collapse_input=True), if the inputs \(\mb{x}_1, \mb{x}_2, \dots, \mb{x}_N\) all have the same (possibly nested) shape S, this operator acts on their stack, i.e., it has input shape (N, *S). If the inputs have distinct shapes S1, S2, …, SN, it acts on their block concatenation, i.e., it has input shape (S1, S2, …, SN). The same rule, governed by collapse_output, applies to the output shape.

Parameters:
  • ops (Sequence[Operator]) – Operators to stack.

  • collapse_input (Optional[bool]) – If True, inputs are expected to be stacked along the first dimension when possible.

  • collapse_output (Optional[bool]) – If True, the output will be stacked along the first dimension when possible.

  • jit (bool) – See jit in Operator.
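The map H and the stacking rule for its output can be sketched in NumPy, with a tuple standing in for a BlockArray. diagonal_stack is a hypothetical function, not scico's implementation: operator i is applied to input block i, and the outputs are stacked along a new first axis only when they all share one shape and collapse_output is True.

```python
import numpy as np


def diagonal_stack(ops, xs, collapse_output=True):
    """Hypothetical sketch of H(x_1, ..., x_N) = (A_1(x_1), ..., A_N(x_N));
    not scico's implementation. `xs` is a stacked array of shape (N, *S)
    or a sequence of differently shaped blocks."""
    if len(ops) != len(xs):
        raise ValueError("need exactly one input block per operator")
    outputs = [A(x) for A, x in zip(ops, xs)]
    if collapse_output and len({np.shape(y) for y in outputs}) == 1:
        return np.stack(outputs)  # common shape S -> output shape (N, *S)
    return tuple(outputs)  # tuple stands in for a BlockArray


ops = [lambda x: 2.0 * x, lambda x: x + 1.0]
stacked_in = np.array([[1.0, 2.0], [3.0, 4.0]])  # shape (N, *S) = (2, 2)
out = diagonal_stack(ops, stacked_in)
```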

static check_if_stackable(ops)

Check that input ops are suitable for stack creation.

class scico.operator.VerticalStack(ops, collapse_output=True, jit=True, **kwargs)

Bases: Operator

A vertical stack of operators.

Given operators \(A_1, A_2, \dots, A_N\), create the operator \(H\) such that

\[\begin{split}H(\mb{x}) = \begin{pmatrix} A_1(\mb{x}) \\ A_2(\mb{x}) \\ \vdots \\ A_N(\mb{x}) \\ \end{pmatrix} \;.\end{split}\]

Parameters:
  • ops (Sequence[Operator]) – Operators to stack.

  • collapse_output (Optional[bool]) – If True and the output would be a BlockArray with shape ((m, n, …), (m, n, …), …), the output is instead a jax.Array with shape (S, m, n, …) where S is the length of ops.

  • jit (bool) – See jit in Operator.
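In contrast to DiagonalStack, every operator in a vertical stack receives the same input. The NumPy sketch below is hypothetical (vertical_stack is not scico's implementation) and uses a tuple in place of a BlockArray. It shows the collapse_output rule: equal output shapes (m, n, …) are stacked into one array of shape (S, m, n, …) with S = len(ops).

```python
import numpy as np


def vertical_stack(ops, x, collapse_output=True):
    """Hypothetical sketch of H(x) = (A_1(x), ..., A_N(x)); not scico's
    implementation. Every operator receives the same input x."""
    outputs = [A(x) for A in ops]
    if collapse_output and len({np.shape(y) for y in outputs}) == 1:
        return np.stack(outputs)  # shapes (m, n, ...) -> (S, m, n, ...)
    return tuple(outputs)  # tuple stands in for a BlockArray


x = np.array([1.0, -2.0, 3.0])
out = vertical_stack([np.abs, np.negative, lambda v: v**2], x)
```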

static check_if_stackable(ops)

Check that input ops are suitable for stack creation.

Modules

scico.operator.biconvolve

Biconvolution operator.