# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Flax implementation of different imaging inversion models."""
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
from functools import partial
from typing import Any, Callable, Tuple
import jax.numpy as jnp
from jax import jit, lax, random
from flax.core import Scope # noqa
from flax.linen.module import _Sentinel # noqa
from flax.linen.module import Module, compact
from scico.flax import ResNet
from scico.linop import LinearOperator
from scico.numpy import Array
from scico.typing import DType, PRNGKey, Shape
# The imports of Scope and _Sentinel (above) are required to silence
# "cannot resolve forward reference" warnings when building sphinx api
# docs.
ModuleDef = Any
class MoDLNet(Module):
"""Flax implementation of MoDL :cite:`aggarwal-2019-modl`.
Flax implementation of the model-based deep learning (MoDL)
architecture for inverse problems described in :cite:`aggarwal-2019-modl`.
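    Each of the ``depth`` iterations applies the CNN-based regularizer
    and then enforces data consistency by solving a linear system with
    :func:`cg_solver`:

    .. math::
        \mb{z}^k = \mathrm{CNN}(\mb{x}^k) \;, \qquad
        \mb{x}^{k+1} = (A^H A + \lambda I)^{-1}
        \left( A^H \mb{y} + \lambda \, \mb{z}^k \right) \;.
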
Args:
operator: Operator for computing forward and adjoint mappings.
depth: Depth of MoDL net. Default: 1.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
block_depth: Number of layers in the computational block.
kernel_size: Size of the convolution filters. Default: (3, 3).
strides: Convolution strides. Default: (1, 1).
lmbda_ini: Initial value of the regularization weight `lambda`.
Default: 0.5.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
cg_iter: Number of iterations for cg solver. Default: 10.
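
    Example:
        A minimal usage sketch; the identity operator and the array
        shapes below are assumptions chosen for illustration, and any
        :class:`scico.linop.LinearOperator` exposing ``adj``,
        ``input_shape`` and ``output_shape`` can be substituted::

            import jax
            import jax.numpy as jnp
            from scico.linop import Identity

            A = Identity((64, 64))         # hypothetical forward operator
            y = jnp.zeros((2, 64, 64, 1))  # batch of two measurements
            model = MoDLNet(
                operator=A,
                depth=2,
                channels=1,
                num_filters=16,
                block_depth=3,
            )
            variables = model.init(jax.random.PRNGKey(0), y)
            x = model.apply(variables, y, train=False)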
"""
operator: ModuleDef
depth: int
channels: int
num_filters: int
block_depth: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
lmbda_ini: float = 0.5
dtype: Any = jnp.float32
cg_iter: int = 10
    @compact
def __call__(self, y: Array, train: bool = True) -> Array:
"""Apply MoDL net for inversion.
Args:
y: The array with signal to invert.
train: Flag to differentiate between training and testing
stages.
Returns:
The reconstructed signal.
"""
def lmbda_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:
return jnp.ones(shape, dtype) * self.lmbda_ini
lmbda = self.param("lmbda", lmbda_init_wrap, (1,))
resnet = ResNet(
self.block_depth,
self.channels,
self.num_filters,
self.kernel_size,
self.strides,
dtype=self.dtype,
)
ah_f = lambda v: jnp.atleast_3d(self.operator.adj(v.reshape(self.operator.output_shape)))
Ahb = lax.map(ah_f, y)
x = Ahb
ahaI_f = lambda v: self.operator.adj(self.operator(v)) + lmbda * v
cgsol = lambda b: jnp.atleast_3d(
cg_solver(ahaI_f, b.reshape(self.operator.input_shape), maxiter=self.cg_iter)
)
for i in range(self.depth):
z = resnet(x, train)
# Solve:
# (AH A + lmbda I) x = Ahb + lmbda * z
b = Ahb + lmbda * z
x = lax.map(cgsol, b)
return x
def cg_solver(A: Callable, b: Array, x0: Array = None, maxiter: int = 50) -> Array:
r"""Conjugate gradient solver.
Solve the linear system :math:`A\mb{x} = \mb{b}`, where :math:`A` is
    positive definite, via the conjugate gradient method. This is a
    lightweight version constructed to be differentiable with the
    automatic differentiation functionality of JAX. Therefore, (i) it
    uses :func:`jax.lax.scan` to execute a fixed number of iterations
    and (ii) it assumes that the linear operator may use
    :func:`jax.pure_callback`. Because it uses a while loop,
    :func:`scico.solver.cg` is not differentiable by JAX, and
    :func:`jax.scipy.sparse.linalg.cg` does not support functions using
    :func:`jax.pure_callback`, which is why an additional conjugate
    gradient function has been implemented.
Args:
        A: Function implementing the linear operator :math:`A`, which
            should be positive definite.
b: Input array :math:`\mb{b}`.
x0: Initial solution. Default: ``None``.
maxiter: Maximum iterations. Default: 50.
Returns:
x: Solution array.
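
    Example:
        A minimal sketch on a small, explicitly constructed positive
        definite system; the matrix ``M`` below is an assumption chosen
        for illustration::

            import jax
            import jax.numpy as jnp

            key = jax.random.PRNGKey(0)
            J = jax.random.normal(key, (8, 8))
            M = J.T @ J + jnp.eye(8)  # symmetric positive definite
            b = jnp.ones(8)
            x = cg_solver(lambda v: M @ v, b, maxiter=20)
            # x approximately satisfies M @ x == b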
"""
def fun(carry, _):
"""Function implementing one iteration of the conjugate gradient solver."""
x, r, p, num = carry
Ap = A(p)
alpha = num / (p.ravel().conj().T @ Ap.ravel())
x = x + alpha * p
r = r - alpha * Ap
num_old = num
num = r.ravel().conj().T @ r.ravel()
beta = num / num_old
p = r + beta * p
return (x, r, p, num), None
if x0 is None:
x0 = jnp.zeros_like(b)
r0 = b - A(x0)
num0 = r0.ravel().conj().T @ r0.ravel()
carry = (x0, r0, r0, num0)
carry, _ = lax.scan(fun, carry, xs=None, length=maxiter)
return carry[0]
class ODPProxDnBlock(Module):
"""Flax implementation of ODP proximal gradient denoise block.
Flax implementation of the unrolled optimization with deep priors
(ODP) proximal gradient block for denoising :cite:`diamond-2018-odp`.
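    The block computes

    .. math::
        \mb{x}^{k+1} = \frac{\mathrm{CNN}(\mb{x}^k) + \alpha \, \mb{y}}
        {1 + \alpha} \;,

    i.e. the proximal operator of the data fidelity term
    :math:`(\alpha / 2) \| \mb{x} - \mb{y} \|_2^2` evaluated at the
    output of the CNN.
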
Args:
operator: Operator for computing forward and adjoint mappings.
In this case it corresponds to the identity operator and is
used at the network level.
depth: Number of layers in block.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
kernel_size: Size of the convolution filters. Default: (3, 3).
strides: Convolution strides. Default: (1, 1).
alpha_ini: Initial value of the fidelity weight `alpha`.
Default: 0.2.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
operator: ModuleDef
depth: int
channels: int
num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
alpha_ini: float = 0.2
dtype: Any = jnp.float32
    def batch_op_adj(self, y: Array) -> Array:
"""Batch application of adjoint operator."""
return self.operator.adj(y)
    @compact
def __call__(self, x: Array, y: Array, train: bool = True) -> Array:
"""Apply denoising block.
Args:
x: The array with current stage of denoised signal.
y: The array with noisy signal.
train: Flag to differentiate between training and testing
stages.
Returns:
The block output (i.e. next stage of denoised signal).
"""
def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:
return jnp.ones(shape, dtype) * self.alpha_ini
alpha = self.param("alpha", alpha_init_wrap, (1,))
resnet = ResNet(
self.depth,
self.channels,
self.num_filters,
self.kernel_size,
self.strides,
dtype=self.dtype,
)
x = (resnet(x, train) + y * alpha) / (1.0 + alpha)
return x
[docs]class ODPProxDcnvBlock(Module):
"""Flax implementation of ODP proximal gradient deconvolution block.
Flax implementation of the unrolled optimization with deep priors
(ODP) proximal gradient block for deconvolution under Gaussian noise
:cite:`diamond-2018-odp`.
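    The block computes

    .. math::
        \mb{x}^{k+1} = \mathcal{F}^{-1} \left( \mathcal{F} \left(
        \alpha \, A^H \mb{y} + \mathrm{CNN}(\mb{x}^k) \right) /
        \left( \alpha \, \|A\|_2^2 + 1 \right) \right) \;,

    where :math:`\mathcal{F}` denotes the DFT over the spatial
    dimensions and :math:`\|A\|_2` is the spectral norm of the forward
    operator, estimated via :func:`power_iteration`.
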
Args:
operator: Operator for computing forward and adjoint mappings.
            In this case it corresponds to a circular convolution operator.
depth: Number of layers in block.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
kernel_size: Size of the convolution filters. Default: (3, 3).
strides: Convolution strides. Default: (1, 1).
alpha_ini: Initial value of the fidelity weight `alpha`.
Default: 0.99.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
operator: ModuleDef
depth: int
channels: int
num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
alpha_ini: float = 0.99
dtype: Any = jnp.float32
    def setup(self):
        """Compute the operator norm, set up the operator for batch
        evaluation, and define the network layers."""
self.operator_norm = jnp.sqrt(power_iteration(self.operator.H @ self.operator)[0].real)
self.ah_f = lambda v: jnp.atleast_3d(
self.operator.adj(v.reshape(self.operator.output_shape))
)
self.resnet = ResNet(
self.depth,
self.channels,
self.num_filters,
self.kernel_size,
self.strides,
dtype=self.dtype,
)
def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:
return jnp.ones(shape, dtype) * self.alpha_ini
self.alpha = self.param("alpha", alpha_init_wrap, (1,))
    def batch_op_adj(self, y: Array) -> Array:
"""Batch application of adjoint operator."""
return lax.map(self.ah_f, y)
    def __call__(self, x: Array, y: Array, train: bool = True) -> Array:
        """Apply deblurring block.
Args:
x: The array with current stage of reconstructed signal.
y: The array with signal to invert.
train: Flag to differentiate between training and testing
stages.
Returns:
The block output (i.e. next stage of reconstructed signal).
"""
# DFT over spatial dimensions
fft_shape: Shape = x.shape[1:-1]
fft_axes: Tuple[int, int] = (1, 2)
scale = 1.0 / (self.alpha * self.operator_norm**2 + 1)
x = jnp.fft.irfftn(
jnp.fft.rfftn(
self.alpha * self.batch_op_adj(y) + self.resnet(x, train),
s=fft_shape,
axes=fft_axes,
)
            * scale,
s=fft_shape,
axes=fft_axes,
)
return x
class ODPGrDescBlock(Module):
r"""Flax implementation of ODP gradient descent with :math:`\ell_2` loss block.
Flax implementation of the unrolled optimization with deep priors
(ODP) gradient descent block for inversion using :math:`\ell_2` loss
described in :cite:`diamond-2018-odp`.
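    The block computes the gradient descent step

    .. math::
        \mb{x}^{k+1} = \mathrm{CNN}(\mb{x}^k) - \alpha \, A^H
        \left( A \, \mb{x}^k - \mb{y} \right) \;.
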
Args:
operator: Operator for computing forward and adjoint mappings. In
this case it corresponds to the identity operator and is used
at the network level.
depth: Number of layers in block.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
kernel_size: Size of the convolution filters. Default: (3, 3).
strides: Convolution strides. Default: (1, 1).
alpha_ini: Initial value of the fidelity weight `alpha`.
Default: 0.2.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
operator: ModuleDef
depth: int
channels: int
num_filters: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
alpha_ini: float = 0.2
dtype: Any = jnp.float32
    def setup(self):
        """Set up the operator for batch evaluation and define the network layers."""
self.ah_f = lambda v: jnp.atleast_3d(
self.operator.adj(v.reshape(self.operator.output_shape))
)
self.a_f = lambda v: jnp.atleast_3d(self.operator(v.reshape(self.operator.input_shape)))
self.resnet = ResNet(
self.depth,
self.channels,
self.num_filters,
self.kernel_size,
self.strides,
dtype=self.dtype,
)
def alpha_init_wrap(rng: PRNGKey, shape: Shape, dtype: DType = self.dtype) -> Array:
return jnp.ones(shape, dtype) * self.alpha_ini
self.alpha = self.param("alpha", alpha_init_wrap, (1,))
    def batch_op_adj(self, y: Array) -> Array:
"""Batch application of adjoint operator."""
return lax.map(self.ah_f, y)
    def __call__(self, x: Array, y: Array, train: bool = True) -> Array:
"""Apply gradient descent block.
Args:
x: The array with current stage of reconstructed signal.
y: The array with signal to invert.
train: Flag to differentiate between training and testing
stages.
Returns:
The block output (i.e. next stage of inverted signal).
"""
x = self.resnet(x, train) - self.alpha * self.batch_op_adj(lax.map(self.a_f, x) - y)
return x
class ODPNet(Module):
"""Flax implementation of ODP network :cite:`diamond-2018-odp`.
Flax implementation of the unrolled optimization with deep priors
(ODP) network for inverse problems described in
:cite:`diamond-2018-odp`. It can be constructed with proximal gradient
blocks or gradient descent blocks.
Args:
operator: Operator for computing forward and adjoint mappings.
        depth: Depth of ODP net, i.e. the number of unrolled blocks. Default: 1.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the output
tensor.
block_depth: Number of layers in the computational block.
kernel_size: Size of the convolution filters. Default: (3, 3).
strides: Convolution strides. Default: (1, 1).
alpha_ini: Initial value of the fidelity weight `alpha`.
Default: 0.5.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
        odp_block: Processing block to apply. Default:
            :class:`ODPProxDnBlock`.
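
    Example:
        A minimal usage sketch for denoising with the default
        :class:`ODPProxDnBlock`; the identity operator acting on the
        full batch and the array shapes below are assumptions chosen
        for illustration::

            import jax
            import jax.numpy as jnp
            from scico.linop import Identity

            y = jnp.zeros((4, 32, 32, 1))  # batch of four noisy images
            A = Identity(y.shape)          # identity applied at the network level
            model = ODPNet(
                operator=A,
                depth=3,
                channels=1,
                num_filters=16,
                block_depth=2,
            )
            variables = model.init(jax.random.PRNGKey(0), y)
            x = model.apply(variables, y, train=False)

        Deconvolution or gradient descent variants can be constructed
        by passing ``odp_block=ODPProxDcnvBlock`` or
        ``odp_block=ODPGrDescBlock`` together with the corresponding
        forward operator.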
"""
operator: ModuleDef
depth: int
channels: int
num_filters: int
block_depth: int
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
alpha_ini: float = 0.5
dtype: Any = jnp.float32
odp_block: Callable = ODPProxDnBlock
    @compact
def __call__(self, y: Array, train: bool = True) -> Array:
"""Apply ODP net for inversion.
Args:
y: The array with signal to invert.
train: Flag to differentiate between training and testing
stages.
Returns:
The reconstructed signal.
"""
block = partial(
self.odp_block,
operator=self.operator,
depth=self.block_depth,
channels=self.channels,
num_filters=self.num_filters,
kernel_size=self.kernel_size,
strides=self.strides,
dtype=self.dtype,
)
        # The first block also computes the initial inversion x0 = A^H y
        # via its batch_op_adj method, since not all operators can be
        # applied directly to batched inputs.
alpha0_i = self.alpha_ini
block0 = block(alpha_ini=alpha0_i)
x = block0.batch_op_adj(y)
x = block0(x, y, train)
alpha0_i /= 2.0
for i in range(self.depth - 1):
x = block(alpha_ini=alpha0_i)(x, y, train)
alpha0_i /= 2.0
return x
@partial(jit, static_argnums=0)
def power_iteration(A: LinearOperator, maxiter: int = 100):
"""Compute largest eigenvalue of a diagonalizable :class:`.LinearOperator`.
Compute largest eigenvalue of a diagonalizable :class:`LinearOperator`
using power iteration. This function has the same functionality as
    :func:`.linop.power_iteration` but is implemented using lax operations
    to allow jitting and general JAX function composition.
Args:
A: :class:`LinearOperator` used for computation. Must be diagonalizable.
maxiter: Maximum number of power iterations to use.
Returns:
tuple: A tuple (`mu`, `v`) containing:
- **mu**: Estimate of largest eigenvalue of `A`.
- **v**: Eigenvector of `A` with eigenvalue `mu`.
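
    Example:
        A minimal sketch using a diagonal operator, for which the
        largest eigenvalue equals the largest diagonal entry; the
        operator below is an assumption chosen for illustration::

            import jax.numpy as jnp
            from scico.linop import Diagonal

            D = Diagonal(jnp.array([1.0, 3.0, 2.0]))
            mu, v = power_iteration(D)
            # mu is approximately 3.0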
"""
key = random.PRNGKey(0)
v = random.normal(key, shape=A.input_shape, dtype=A.input_dtype)
v = v / jnp.linalg.norm(v)
init_val = (0, v, v, 1.0)
def cond_fun(val):
return jnp.logical_and(val[0] <= maxiter, val[3] > 0.0)
def body_fun(val):
i, v, Av, normAv = val
v = Av / normAv
i = i + 1
Av = A @ v
normAv = jnp.linalg.norm(Av)
return (i, v, Av, normAv)
def true_fun(v, Av, normAv):
return jnp.sum(v.conj() * Av) / jnp.linalg.norm(v) ** 2, Av / normAv
def false_fun(v, Av, normAv):
return 0.0 * normAv, Av # Multiplication by zero used to preserve data type
i, v, Av, normAv = lax.while_loop(cond_fun, body_fun, init_val)
mu, v = lax.cond(normAv > 0.0, true_fun, false_fun, v, Av, normAv)
return mu, v