scico.operator
Operator functions and classes.
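Modules
- Biconvolution operator.
Functions
- operator_from_function – Make an Operator from a function.
Classes
- Abs – Operator version of scico.numpy.abs.
- Angle – Operator version of scico.numpy.angle.
- BiConvolve – Biconvolution operator.
- DiagonalReplicated – A diagonal stack constructed from a single operator.
- DiagonalStack – A diagonal stack of operators.
- Exp – Operator version of scico.numpy.exp.
- Operator – Generic operator class.
- VerticalStack – A vertical stack of operators.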
- class scico.operator.Operator(input_shape, output_shape=None, eval_fn=None, input_dtype=<class 'numpy.float32'>, output_dtype=None, jit=False)
Bases: object
Generic operator class.
- Parameters:
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.
output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.
eval_fn (Optional[Callable]) – Function used in evaluating this Operator. Defaults to None. Required unless __init__ is being called from a derived class with an _eval method.
input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.
output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.
jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.
- Raises:
NotImplementedError – If the eval_fn parameter is not specified and the _eval method is not defined in a derived class.
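The rule that a missing output_shape or output_dtype is obtained by evaluating the operator on an array of zeros can be sketched without scico. The class below, `MiniOperator`, is a hypothetical NumPy-only stand-in written for illustration; it mimics the documented defaults and the NotImplementedError condition but is not scico's implementation.

```python
import numpy as np


class MiniOperator:
    """Hypothetical stand-in for scico.operator.Operator (illustration only)."""

    def __init__(self, input_shape, output_shape=None, eval_fn=None,
                 input_dtype=np.float32, output_dtype=None):
        if eval_fn is None and not hasattr(self, "_eval"):
            # Mirrors the documented NotImplementedError condition.
            raise NotImplementedError("eval_fn must be given or _eval defined")
        if eval_fn is not None:
            self._eval = eval_fn  # stored unbound, so called as self._eval(x)
        self.input_shape = input_shape
        self.input_dtype = input_dtype
        if output_shape is None or output_dtype is None:
            # Infer whatever is missing by evaluating on zeros of the input shape.
            y = self._eval(np.zeros(input_shape, dtype=input_dtype))
            output_shape = y.shape if output_shape is None else output_shape
            output_dtype = y.dtype if output_dtype is None else output_dtype
        self.output_shape = output_shape
        self.output_dtype = output_dtype

    def __call__(self, x):
        return self._eval(x)


# A row-sum operator: output shape and dtype are inferred, not given.
H = MiniOperator((3, 4), eval_fn=lambda x: x.sum(axis=1))
print(H.output_shape, H.output_dtype)  # (3,) float32
```

One consequence of this design is that eval_fn must accept an all-zeros input without error, since it is called once at construction whenever either output property is left as None.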
- __call__(x)
Evaluate this Operator at the point \(\mb{x}\).
- Parameters:
x (Union[Operator, Array, BlockArray]) – Point at which to evaluate this Operator. If x is a jax.Array or BlockArray, it must have shape == self.input_shape. If x is an Operator or LinearOperator, it must have x.output_shape == self.input_shape.
- Return type:
Union[Operator, Array, BlockArray]
- Returns:
Operator evaluated at x.
- Raises:
ValueError – If the input_shape attribute of the Operator is not equal to the input array shape, or to the output_shape attribute of another Operator with which it is composed.
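When x is itself an operator, __call__ yields a composition rather than an array, subject to the check x.output_shape == self.input_shape. The sketch below illustrates that rule with a hypothetical `Op` class carrying explicit shapes; it is an illustration of the documented behavior, not scico's code.

```python
import numpy as np


class Op:
    """Hypothetical operator with fixed shapes (illustration only)."""

    def __init__(self, input_shape, output_shape, fn):
        self.input_shape, self.output_shape, self.fn = input_shape, output_shape, fn

    def __call__(self, x):
        if isinstance(x, Op):
            # Composition self(x(.)): x's output must match self's input.
            if x.output_shape != self.input_shape:
                raise ValueError("shape mismatch in composition")
            return Op(x.input_shape, self.output_shape, lambda z: self.fn(x.fn(z)))
        if x.shape != self.input_shape:
            raise ValueError("input shape mismatch")
        return self.fn(x)


A = Op((4,), (2,), lambda x: x[:2] + x[2:])  # (4,) -> (2,)
B = Op((2,), (2,), lambda x: 2.0 * x)        # (2,) -> (2,)
C = B(A)                                     # composed operator, (4,) -> (2,)
print(C(np.arange(4.0)))  # [4. 8.]
```

Composing in the other order, A(B), raises ValueError because B's output shape (2,) does not match A's input shape (4,).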
- freeze(argnum, val)
Return a new Operator with block argument argnum fixed to value val.
- Parameters:
argnum (int) – Index of block to freeze. Must be less than the number of blocks in an input array.
val (Union[Array, BlockArray]) – Value to fix the argnum-th input to.
- Return type:
Operator
- Returns:
A new Operator with one of the blocks of the input fixed to the specified value.
- Raises:
ValueError – If the Operator does not take a BlockArray as its input, if the block index equals or exceeds the number of blocks, or if the shape of the fixed value differs from the shape of the specified block.
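Freezing a block turns a function of all blocks into a function of the remaining ones. The sketch below uses a plain tuple of NumPy arrays as a stand-in for a BlockArray, and a hypothetical `freeze` helper whose checks mirror the documented ValueError conditions; it is not scico's implementation.

```python
import numpy as np


def freeze(fn, input_shape, argnum, val):
    """Hypothetical helper: fix block `argnum` of a tuple-input function."""
    if argnum >= len(input_shape):
        raise ValueError("argnum must be less than the number of blocks")
    if val.shape != input_shape[argnum]:
        raise ValueError("val has the wrong shape for this block")

    def frozen(*blocks):
        # Reinsert the fixed value at position argnum before evaluating.
        full = list(blocks)
        full.insert(argnum, val)
        return fn(tuple(full))

    return frozen


def f(blocks):
    x, y = blocks
    return x * y.sum()


g = freeze(f, ((3,), (2,)), argnum=1, val=np.array([1.0, 2.0]))
print(g(np.array([1.0, 0.0, -1.0])))  # [ 3.  0. -3.]
```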
- jvp(u, v)
Compute a Jacobian-vector product.
Compute the product \(J_F(\mb{u}) \mb{v}\) where \(F\) represents this operator and \(J_F(\mb{u})\) is the Jacobian of \(F\) evaluated at \(\mb{u}\). This method is implemented via a call to jax.jvp.
- Parameters:
u – Value at which the Jacobian is evaluated.
v – Vector in the Jacobian-vector product.
- Returns:
A pair \((F(\mb{u}), J_F(\mb{u}) \mb{v})\), i.e. a pair consisting of the operator evaluated at \(\mb{u}\) and the Jacobian-vector product.
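Since jvp is implemented via jax.jvp, its output pair can be reproduced with JAX directly. In the sketch below a plain function stands in for \(F\) (an elementwise square, chosen for illustration), so the Jacobian is diagonal with entries \(2\mb{u}\) and the tangent output should equal \(2\mb{u} \odot \mb{v}\).

```python
import jax
import jax.numpy as jnp


def F(u):
    # Elementwise square, so J_F(u) = diag(2u).
    return u ** 2


u = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])

# jax.jvp takes tuples of primals and tangents and returns (F(u), J_F(u) v).
Fu, Jv = jax.jvp(F, (u,), (v,))
print(Fu)  # [1. 4. 9.]
print(Jv)  # [ 2.  0. -6.]
```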
- vjp(u, conjugate=True)
Compute a vector-Jacobian product.
Compute the product \([J_F(\mb{u})]^T \mb{v}\), or \([J_F(\mb{u})]^H \mb{v}\) when conjugate is True, where \(F\) represents this operator and \(J_F(\mb{u})\) is the Jacobian of \(F\) evaluated at \(\mb{u}\). Instead of computing the vector-Jacobian product directly, this method returns a function that takes \(\mb{v}\) as its argument and returns the product. This method is implemented via a call to jax.vjp.
- Parameters:
u – Value at which the Jacobian is evaluated.
conjugate – If True, compute the product using the conjugate (Hermitian) transpose.
- Returns:
A pair \((F(\mb{u}), G(\cdot))\) where \(G(\cdot)\) is a function that computes the vector-Jacobian product, i.e. \(G(\mb{v}) = [J_F(\mb{u})]^T \mb{v}\) when conjugate is False, or \(G(\mb{v}) = [J_F(\mb{u})]^H \mb{v}\) when conjugate is True.
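The difference between the two settings of conjugate can be checked on a linear map \(F(\mb{u}) = A\mb{u}\), whose Jacobian is \(A\) at every point. As we understand JAX's convention, jax.vjp applies the plain (unconjugated) transpose; one way to obtain the Hermitian product is to conjugate the cotangent going in and the result coming out. This is a sketch of that idea under those assumptions, not necessarily how scico implements the option.

```python
import jax
import jax.numpy as jnp

# A linear map F(u) = A u, so J_F(u) = A at every u.
A = jnp.array([[1.0 + 2.0j, 0.5j], [3.0, -1.0j], [2.0 - 1.0j, 1.0]])


def F(u):
    return A @ u


u = jnp.array([1.0 + 0.0j, -1.0j])
v = jnp.array([1.0j, 2.0, 1.0 - 1.0j])

Fu, G = jax.vjp(F, u)  # G returns a tuple with one cotangent per input

# Plain transpose product, matching conjugate=False: A^T v.
(vjp_t,) = G(v)
# Conjugate in, conjugate out: the Hermitian product A^H v.
vjp_h = jnp.conj(G(jnp.conj(v))[0])
```

For real-valued operators the two products coincide, which is why the distinction only matters when input_dtype is complex.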
- class scico.operator.BiConvolve(input_shape, input_dtype=<class 'numpy.float32'>, mode='full', jit=True)
Bases: Operator
Biconvolution operator.
A BiConvolve operator accepts a BlockArray input with two blocks of equal ndims, and convolves the first block with the second.
If A is a BiConvolve operator, then A(snp.blockarray([x, h])) equals jax.scipy.signal.convolve(x, h) for the default mode “full” (and, more generally, the result of jax.scipy.signal.convolve with the same mode).
- Parameters:
input_shape (Tuple[Tuple[int, ...], Tuple[int, ...]]) – Shape of input BlockArray. Must correspond to a BlockArray with two blocks of equal ndims.
input_dtype (DType) – dtype for input argument. Defaults to float32.
mode (str) – A string indicating the size of the output. One of “full”, “valid”, “same”. Defaults to “full”.
For more details on mode, see jax.scipy.signal.convolve.
- freeze(argnum, val)
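Freeze the argnum parameter.
Return a new LinearOperator with block argument argnum fixed to value val.
If argnum == 0, a ConvolveByX object is returned. If argnum == 1, a Convolve object is returned.
- Parameters:
argnum (int) – Index of block to freeze (0 or 1).
val – Value to fix the argnum-th input to.
- Return type:
LinearOperator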
- scico.operator.operator_from_function(f, classname, f_name=None)
Make an Operator from a function.
Example
>>> import scico.numpy as snp
>>> from scico.operator import operator_from_function
>>> AbsVal = operator_from_function(snp.abs, 'AbsVal')
>>> H = AbsVal((2,))
>>> H(snp.array([1.0, -1.0]))
Array([1., 1.], dtype=float32)
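The example above returns a new class, not an instance, which is why AbsVal must then be called with an input shape. A rough idea of how such a factory can be built with Python's built-in type() is sketched below. `BaseOp` and `make_operator_class` are hypothetical names; scico's own function also handles dtypes, jit, and extra arguments, and may be structured differently.

```python
import numpy as np


class BaseOp:
    """Simplified, hypothetical operator base class."""

    def __init__(self, input_shape, input_dtype=np.float32):
        self.input_shape = input_shape
        self.input_dtype = input_dtype

    def __call__(self, x):
        return self._eval(x)


def make_operator_class(f, classname, f_name=None):
    """Hypothetical factory: create a BaseOp subclass that evaluates f."""
    f_name = f_name or getattr(f, "__name__", "f")

    def _eval(self, x):
        return f(x)

    doc = f"Operator version of {f_name}."
    return type(classname, (BaseOp,), {"_eval": _eval, "__doc__": doc})


AbsVal = make_operator_class(np.abs, "AbsVal", f_name="numpy.abs")
H = AbsVal((2,))
print(H(np.array([1.0, -1.0])))  # [1. 1.]
```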
- class scico.operator.Abs(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)
Bases: Operator
Operator version of scico.numpy.abs.
- Parameters:
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.
args (Any) – Positional arguments passed to scico.numpy.abs.
input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.
output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.
output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.
jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.
**kwargs (Any) – Keyword arguments passed to scico.numpy.abs.
- class scico.operator.Angle(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)
Bases: Operator
Operator version of scico.numpy.angle.
- Parameters:
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.
args (Any) – Positional arguments passed to scico.numpy.angle.
input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.
output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.
output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.
jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.
**kwargs (Any) – Keyword arguments passed to scico.numpy.angle.
- class scico.operator.Exp(input_shape, *args, input_dtype=<class 'jax.numpy.float32'>, output_shape=None, output_dtype=None, jit=True, **kwargs)
Bases: Operator
Operator version of scico.numpy.exp.
- Parameters:
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...]]) – Shape of input array.
args (Any) – Positional arguments passed to scico.numpy.exp.
input_dtype (DType) – dtype for input argument. Defaults to float32. If the Operator implements complex-valued operations, this must be a complex dtype (typically complex64) for correct adjoint and gradient calculation.
output_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]) – Shape of output array. Defaults to None. If None, output_shape is determined by evaluating self.__call__ on an input array of zeros.
output_dtype (Optional[DType]) – dtype for output argument. Defaults to None. If None, output_dtype is determined by evaluating self.__call__ on an input array of zeros.
jit (bool) – If True, call Operator.jit on this Operator to jit the forward, adjoint, and gram functions. Same as calling Operator.jit after the Operator is created.
**kwargs (Any) – Keyword arguments passed to scico.numpy.exp.
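Abs, Angle, and Exp act elementwise, but their output dtype need not match the input dtype. Because output_dtype is inferred by evaluating the operator on zeros when it is None, the dtype rules of the wrapped function decide it. The check below uses NumPy's abs, angle, and exp as stand-ins for the scico.numpy versions, assuming their dtype rules match: a complex64 input gives a float32 output for abs and angle, but complex64 for exp.

```python
import numpy as np

z = np.zeros((4,), dtype=np.complex64)  # zeros, as used for dtype inference

for name, fn in [("abs", np.abs), ("angle", np.angle), ("exp", np.exp)]:
    print(name, fn(z).dtype)
# abs float32
# angle float32
# exp complex64

# Values on a few nonzero complex inputs.
w = np.array([1.0j, -1.0, 1.0 + 1.0j], dtype=np.complex64)
print(np.abs(w), np.angle(w))
```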
- class scico.operator.DiagonalStack(ops, collapse_input=True, collapse_output=True, jit=True, **kwargs)
Bases: Operator
A diagonal stack of operators.
Given operators \(A_1, A_2, \dots, A_N\), create the operator \(H\) such that
\[
H \left( \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \end{pmatrix} \right) = \begin{pmatrix} A_1(\mb{x}_1) \\ A_2(\mb{x}_2) \\ \vdots \\ A_N(\mb{x}_N) \end{pmatrix} \;.
\]
By default, if the inputs \(\mb{x}_1, \mb{x}_2, \dots, \mb{x}_N\) all have the same (possibly nested) shape, S, this operator will work on the stack, i.e., have an input shape of (N, *S). If the inputs have distinct shapes, S1, S2, …, SN, this operator will work on the block concatenation, i.e., have an input shape of (S1, S2, …, SN). The same holds for the output shape.
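The shape rule above can be illustrated with NumPy. In this sketch, `diagonal_stack` is a hypothetical helper (not scico's API) that applies the k-th operator to the k-th component; a tuple stands in for a BlockArray in the distinct-shape case.

```python
import numpy as np


def diagonal_stack(ops, x):
    """Hypothetical sketch: apply ops[k] to the k-th component of x."""
    outs = [op(xk) for op, xk in zip(ops, x)]
    # Collapse to a stacked array when all outputs share a shape,
    # otherwise return a tuple of blocks (stand-in for a BlockArray).
    if len({o.shape for o in outs}) == 1:
        return np.stack(outs)
    return tuple(outs)


ops = [lambda v: 2.0 * v, np.sin, lambda v: v ** 2]
x = np.array([[0.0, 1.0], [0.0, np.pi / 2], [3.0, -1.0]])  # shape (N, *S) = (3, 2)
y = diagonal_stack(ops, x)
print(y.shape)  # (3, 2)

# Distinct shapes S1 = (2,), S2 = (3,): the result stays a tuple of blocks.
yb = diagonal_stack([np.exp, np.negative], (np.ones(2), np.ones(3)))
```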
- class scico.operator.VerticalStack(ops, collapse_output=True, jit=True, **kwargs)
Bases: Operator
A vertical stack of operators.
Given operators \(A_1, A_2, \dots, A_N\), create the operator \(H\) such that
\[
H(\mb{x}) = \begin{pmatrix} A_1(\mb{x}) \\ A_2(\mb{x}) \\ \vdots \\ A_N(\mb{x}) \end{pmatrix} \;.
\]
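Unlike a diagonal stack, a vertical stack evaluates every operator at the same point. The NumPy sketch below uses a hypothetical `vertical_stack` helper; its collapse_output handling is an assumption modeled on the signature and on the DiagonalStack description (stack outputs of a common shape along a new leading axis, otherwise return a tuple standing in for a BlockArray).

```python
import numpy as np


def vertical_stack(ops, x, collapse_output=True):
    """Hypothetical sketch: evaluate every op at the same point x."""
    outs = [op(x) for op in ops]
    if collapse_output and len({o.shape for o in outs}) == 1:
        return np.stack(outs)  # shape (N, *S)
    return tuple(outs)  # stand-in for a BlockArray


x = np.array([1.0, -2.0, 3.0])
y = vertical_stack([np.abs, np.negative], x)
print(y)

# Outputs of different shapes, (3,) and (), are kept as separate blocks.
yb = vertical_stack([np.abs, np.sum], x)
```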
- class scico.operator.DiagonalReplicated(op, replicates, input_axis=0, output_axis=None, map_type='auto', **kwargs)
Bases: Operator
A diagonal stack constructed from a single operator.
Given operator \(A\), create the operator \(H\) such that
\[
H \left( \begin{pmatrix} \mb{x}_1 \\ \mb{x}_2 \\ \vdots \\ \mb{x}_N \end{pmatrix} \right) = \begin{pmatrix} A(\mb{x}_1) \\ A(\mb{x}_2) \\ \vdots \\ A(\mb{x}_N) \end{pmatrix} \;.
\]
The application of \(A\) to each component \(\mb{x}_k\) is computed using jax.pmap or jax.vmap. The input shape for operator \(A\) should exclude the array axis on which \(A\) is replicated to form \(H\). For example, if \(A\) has input shape (3, 4) and \(H\) is constructed to replicate on axis 0 with 2 replicates, the input shape of \(H\) will be (2, 3, 4).
Operators taking BlockArray input are not supported.
- Parameters:
op (Operator) – Operator to replicate.
replicates (int) – Number of replicates of op.
input_axis (int) – Input axis over which op should be replicated.
output_axis (Optional[int]) – Index of replication axis in output array. If None, the input replication axis is used.
map_type (str) – If “pmap” or “vmap”, apply replicated mapping using jax.pmap or jax.vmap respectively. If “auto”, use jax.pmap if sufficient devices are available for the number of replicates, otherwise use jax.vmap.
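The vmap path can be reproduced with jax.vmap alone. In the sketch below a plain function stands in for \(A\) (input shape (3, 4), output shape (4,), chosen for illustration), replicated twice on axis 0 as in the example above. Using vmap's out_axes to model output_axis is our assumption about how the two correspond.

```python
import jax
import jax.numpy as jnp


def A(x):
    # Stand-in for an operator with input shape (3, 4) and output shape (4,).
    return x.sum(axis=0)


H = jax.vmap(A, in_axes=0, out_axes=0)  # replicate A over axis 0

x = jnp.arange(24.0).reshape(2, 3, 4)  # input shape of H: (2, 3, 4)
y = H(x)
print(y.shape)  # (2, 4)

# Replicating over axis 2 instead, with the replication axis placed last
# in the output: H2 maps shape (3, 4, 2) to shape (4, 2).
H2 = jax.vmap(A, in_axes=2, out_axes=1)
y2 = H2(jnp.moveaxis(x, 0, 2))
```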