scico.flax.inverse#

Flax implementation of different imaging inversion models.

Functions

cg_solver(A, b[, x0, maxiter])

Conjugate gradient solver.

power_iteration(A[, maxiter])

Compute the largest eigenvalue of a diagonalizable LinearOperator.

Classes

ODPGrDescBlock(operator, depth, channels, ...)

Flax implementation of ODP gradient descent with \(\ell_2\) loss block.

ODPProxDcnvBlock(operator, depth, channels, ...)

Flax implementation of ODP proximal gradient deconvolution block.

ODPProxDnBlock(operator, depth, channels, ...)

Flax implementation of ODP proximal gradient denoise block.

scico.flax.inverse.cg_solver(A, b, x0=None, maxiter=50)[source]#

Conjugate gradient solver.

Solve the linear system \(A\mb{x} = \mb{b}\), where \(A\) is positive definite, via the conjugate gradient method. This is a lightweight version constructed to be differentiable with the autograd functionality from jax. Therefore, (i) it uses jax.lax.scan to execute a fixed number of iterations and (ii) it assumes that the linear operator may use jax.pure_callback. This additional conjugate gradient implementation is necessary because scico.cg is not differentiable by jax due to its use of a while loop, and jax.scipy.sparse.linalg.cg does not support functions using jax.pure_callback.

Parameters:
  • A (Callable) – Function implementing linear operator \(A\), should be positive definite.

  • b (Array) – Input array \(\mb{b}\).

  • x0 (Optional[Array]) – Initial solution. Default: None.

  • maxiter (int) – Maximum iterations. Default: 50.

Return type:

Array

Returns:

x – Solution array.

class scico.flax.inverse.ODPProxDnBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Flax implementation of ODP proximal gradient denoise block.

Inheritance diagram of ODPProxDnBlock

Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for denoising [19].

Parameters:
  • operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to the identity operator and is used at the network level.

  • depth (int) – Number of layers in block.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).

  • alpha_ini (float) – Initial value of the fidelity weight alpha. Default: 0.2.

  • dtype (Any) – Output dtype. Default: float32.

batch_op_adj(y)[source]#

Batch application of adjoint operator.

Return type:

Array

__call__(x, y, train=True)[source]#

Apply denoising block.

Parameters:
  • x (Array) – The current estimate of the denoised signal.

  • y (Array) – The noisy signal.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The block output (i.e. next stage of denoised signal).
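The proximal step underlying a denoising block of this kind has a simple closed form: minimizing the data fidelity \((\alpha/2)\|x - y\|^2\) plus a quadratic penalty keeping \(x\) near a prior estimate \(z\) gives \((z + \alpha y)/(1 + \alpha)\). A schematic sketch, where \(z\) is a stand-in for the output of the block's convolutional layers (not the actual scico code):

```python
import jax.numpy as jnp

def odp_prox_dn_step(z, y, alpha):
    """Closed-form minimizer of (alpha/2)||x - y||^2 + (1/2)||x - z||^2,
    i.e. a proximal denoising step balancing fidelity to the noisy signal
    y against a prior estimate z (here standing in for the CNN output)."""
    return (z + alpha * y) / (1.0 + alpha)
```

As \(\alpha \to \infty\) the step returns the noisy signal unchanged; as \(\alpha \to 0\) it trusts the prior estimate entirely.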

class scico.flax.inverse.ODPProxDcnvBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.99, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Flax implementation of ODP proximal gradient deconvolution block.

Inheritance diagram of ODPProxDcnvBlock

Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for deconvolution under Gaussian noise [19].

Parameters:
  • operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to a circular convolution operator.

  • depth (int) – Number of layers in block.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).

  • alpha_ini (float) – Initial value of the fidelity weight alpha. Default: 0.99.

  • dtype (Any) – Output dtype. Default: float32.

setup()[source]#

Compute the operator norm, set up the operator for batch evaluation, and define the network layers.

batch_op_adj(y)[source]#

Batch application of adjoint operator.

Return type:

Array

__call__(x, y, train=True)[source]#

Apply deblurring block.

Parameters:
  • x (Array) – The current estimate of the reconstructed signal.

  • y (Array) – The signal to invert.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The block output (i.e. next stage of reconstructed signal).
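For a circular convolution operator, the corresponding proximal step has a closed form in the Fourier domain: the minimizer of \((\alpha/2)\|k \ast x - y\|^2 + (1/2)\|x - z\|^2\) can be computed per frequency. A schematic sketch (illustrative only, with \(z\) standing in for the block's CNN output):

```python
import jax.numpy as jnp

def odp_prox_dcnv_step(z, y, kernel_fft, alpha):
    """Closed-form minimizer of (alpha/2)||k * x - y||^2 + (1/2)||x - z||^2
    for circular convolution with kernel k, solved in the Fourier domain."""
    num = alpha * jnp.conj(kernel_fft) * jnp.fft.fft2(y) + jnp.fft.fft2(z)
    den = alpha * jnp.abs(kernel_fft) ** 2 + 1.0  # positive, so division is safe
    return jnp.real(jnp.fft.ifft2(num / den))
```

When the kernel is a delta (all-ones spectrum), this reduces to the denoising update \((z + \alpha y)/(1 + \alpha)\).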

class scico.flax.inverse.ODPGrDescBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]#

Bases: Module

Flax implementation of ODP gradient descent with \(\ell_2\) loss block.

Inheritance diagram of ODPGrDescBlock

Flax implementation of the unrolled optimization with deep priors (ODP) gradient descent block for inversion using \(\ell_2\) loss described in [19].

Parameters:
  • operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to the identity operator and is used at the network level.

  • depth (int) – Number of layers in block.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).

  • alpha_ini (float) – Initial value of the fidelity weight alpha. Default: 0.2.

  • dtype (Any) – Output dtype. Default: float32.

setup()[source]#

Set up the operator for batch evaluation and define the network layers.

batch_op_adj(y)[source]#

Batch application of adjoint operator.

Return type:

Array

__call__(x, y, train=True)[source]#

Apply gradient descent block.

Parameters:
  • x (Array) – The current estimate of the reconstructed signal.

  • y (Array) – The signal to invert.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The block output (i.e. next stage of inverted signal).
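The update performed by a gradient descent block of this kind combines a gradient step on the \(\ell_2\) data fidelity with a learned correction. A schematic sketch, where learned_update is a stand-in for the block's CNN (not the actual scico code):

```python
import jax.numpy as jnp

def odp_grdesc_step(x, y, A, AT, learned_update, alpha):
    """One gradient-descent step on the l2 data fidelity with step weight
    alpha, plus a learned correction (a stand-in for the block's CNN)."""
    grad_fidelity = AT(A(x) - y)  # gradient of (1/2)||A x - y||^2
    return x - alpha * grad_fidelity + learned_update(x)
```

Unrolling depth such steps, each with its own learned correction and fidelity weight, yields the full block.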

scico.flax.inverse.power_iteration(A, maxiter=100)[source]#

Compute the largest eigenvalue of a diagonalizable LinearOperator.

Compute the largest eigenvalue of a diagonalizable LinearOperator using power iteration. This function provides the same functionality as linop.power_iteration but is implemented using lax operations to allow jitting and general jax function composition.

Parameters:
  • A (LinearOperator) – LinearOperator used for computation. Must be diagonalizable.

  • maxiter (int) – Maximum number of power iterations to use.

Returns:

tuple

A tuple (mu, v) containing:

  • mu: Estimate of largest eigenvalue of A.

  • v: Eigenvector of A with eigenvalue mu.
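The lax-based structure can be sketched as follows; this is an illustrative implementation of power iteration, not the actual scico code:

```python
import jax
import jax.numpy as jnp

def power_iter(A, v0, maxiter=100):
    """Estimate the dominant eigenpair of a linear map A by power iteration."""
    def body(i, v):
        Av = A(v)
        return Av / jnp.linalg.norm(Av)  # renormalize each iteration

    # lax.fori_loop runs a fixed number of iterations, so the whole
    # computation can be jitted and composed with other jax functions.
    v = jax.lax.fori_loop(0, maxiter, body, v0 / jnp.linalg.norm(v0))
    mu = jnp.vdot(v, A(v))  # Rayleigh quotient estimate of the eigenvalue
    return mu, v
```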