scico.flax.inverse¶
Flax implementation of different imaging inversion models.
Functions
- cg_solver: Conjugate gradient solver.
- power_iteration: Compute largest eigenvalue of a diagonalizable LinearOperator.
Classes
- ODPGrDescBlock: Flax implementation of ODP gradient descent with \(\ell_2\) loss block.
- ODPProxDcnvBlock: Flax implementation of ODP proximal gradient deconvolution block.
- ODPProxDnBlock: Flax implementation of ODP proximal gradient denoise block.
- scico.flax.inverse.cg_solver(A, b, x0=None, maxiter=50)[source]¶
Conjugate gradient solver.
Solve the linear system \(A\mb{x} = \mb{b}\), where \(A\) is positive definite, via the conjugate gradient method. This is a light version constructed to be differentiable with the autograd functionality from jax. Therefore, (i) it uses jax.lax.scan to execute a fixed number of iterations and (ii) it assumes that the linear operator may use jax.pure_callback. This additional conjugate gradient function is necessary because scico.cg is not differentiable by jax (due to its use of a while loop), and jax.scipy.sparse.linalg.cg does not support functions that use jax.pure_callback.
- class scico.flax.inverse.ODPProxDnBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Bases: Module
Flax implementation of ODP proximal gradient denoise block.
Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for denoising [19].
- Parameters:
  - operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to the identity operator and is used at the network level.
  - depth (int) – Number of layers in block.
  - channels (int) – Number of channels of input tensor.
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - kernel_size (Tuple[int, int]) – Size of the convolution filters.
  - alpha_ini (float) – Initial value of the fidelity weight alpha.
- class scico.flax.inverse.ODPProxDcnvBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.99, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Bases: Module
Flax implementation of ODP proximal gradient deconvolution block.
Flax implementation of the unrolled optimization with deep priors (ODP) proximal gradient block for deconvolution under Gaussian noise [19].
- Parameters:
  - operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to a circular convolution operator.
  - depth (int) – Number of layers in block.
  - channels (int) – Number of channels of input tensor.
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - kernel_size (Tuple[int, int]) – Size of the convolution filters.
  - alpha_ini (float) – Initial value of the fidelity weight alpha.
- setup()[source]¶
Compute the operator norm, set up the operator for batch evaluation, and define the network layers.
- class scico.flax.inverse.ODPGrDescBlock(operator, depth, channels, num_filters, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)[source]¶
Bases: Module
Flax implementation of ODP gradient descent with \(\ell_2\) loss block.
Flax implementation of the unrolled optimization with deep priors (ODP) gradient descent block for inversion using \(\ell_2\) loss described in [19].
- Parameters:
  - operator (Any) – Operator for computing forward and adjoint mappings. In this case it corresponds to the identity operator and is used at the network level.
  - depth (int) – Number of layers in block.
  - channels (int) – Number of channels of input tensor.
  - num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  - kernel_size (Tuple[int, int]) – Size of the convolution filters.
  - alpha_ini (float) – Initial value of the fidelity weight alpha.
- scico.flax.inverse.power_iteration(A, maxiter=100)[source]¶
Compute largest eigenvalue of a diagonalizable LinearOperator.
Compute largest eigenvalue of a diagonalizable LinearOperator using power iteration. This function has the same functionality as linop.power_iteration but is implemented using lax operations to allow jitting and general jax function composition.
- Parameters:
  - A (LinearOperator) – LinearOperator used for computation. Must be diagonalizable.
  - maxiter (int) – Maximum number of power iterations to use.
- Returns:
  tuple – A tuple (mu, v) containing:
  - mu: Estimate of largest eigenvalue of A.
  - v: Eigenvector of A with eigenvalue mu.
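A lax-based power iteration of the kind described above can be sketched as follows. This is an illustrative sketch, not the scico implementation: `A` is assumed to be a callable computing the matrix-vector product, and the explicit `shape` argument is an assumption of this sketch (the scico version obtains the input shape from the LinearOperator itself).

```python
import jax
import jax.numpy as jnp

def power_iteration(A, shape, maxiter=100):
    """Sketch of power iteration via lax.scan (not the scico code)."""
    # deterministic start; assumed not orthogonal to the dominant eigenvector
    v0 = jnp.ones(shape)

    def step(v, _):
        # apply the operator and renormalize
        Av = A(v)
        return Av / jnp.linalg.norm(Av), None

    # fixed number of iterations allows jitting and differentiation
    v, _ = jax.lax.scan(step, v0, None, length=maxiter)
    # Rayleigh quotient gives the eigenvalue estimate
    mu = jnp.vdot(v, A(v)) / jnp.vdot(v, v)
    return mu, v
```

The fixed iteration count trades the convergence test of linop.power_iteration for compatibility with jax.jit and function composition, mirroring the trade-off made by cg_solver above.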