PPP (with BM4D) Volume Deconvolution

This example solves a 3D deconvolution problem, recovering a volume that has been convolved with a 3D kernel and corrupted by noise, using the ADMM Plug-and-Play Priors (PPP) algorithm [50] with the BM4D [36] denoiser.
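
For orientation, the ADMM iteration underlying PPP can be sketched as follows. This is a standard textbook form, not a transcription of the library's implementation: it assumes the data fidelity term is $\tfrac{1}{2}\|A x - y\|_2^2$ and the splitting constraint is $x = z$. The symbol $D$ (the denoiser, BM4D in this example), the auxiliary variable $z$, and the scaled dual variable $u$ are notation introduced here for illustration.

```latex
\begin{aligned}
x^{(k+1)} &= \operatorname*{arg\,min}_{x} \; \tfrac{1}{2} \| A x - y \|_2^2
             + \tfrac{\rho}{2} \| x - z^{(k)} + u^{(k)} \|_2^2 \\
z^{(k+1)} &= D\big( x^{(k+1)} + u^{(k)} \big) \\
u^{(k+1)} &= u^{(k)} + x^{(k+1)} - z^{(k+1)}
\end{aligned}
```

Under these assumptions the x-update is a linear least-squares problem, equivalent to solving $(A^T A + \rho I)\, x = A^T y + \rho\,(z^{(k)} - u^{(k)})$, and the z-update replaces the proximal operator of the regularizer with a call to the denoiser.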

[1]:
import numpy as np

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.examples import create_3d_foam_phantom, downsample_volume, tile_volume_slices
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()

Create a ground truth volume. A foam phantom is generated at twice the target resolution and then downsampled.

[2]:
np.random.seed(1234)
N = 128  # phantom size
Nx, Ny, Nz = N, N, N // 4
upsamp = 2
x_gt_hires = create_3d_foam_phantom((upsamp * Nz, upsamp * Ny, upsamp * Nx), N_sphere=100)
x_gt = downsample_volume(x_gt_hires, upsamp)
x_gt = snp.array(x_gt)  # convert to jax array
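
Generating the phantom at a higher resolution and then downsampling softens the hard edges of the foam spheres. The sketch below shows one common way to downsample a volume, by averaging non-overlapping blocks. It is an assumption that `downsample_volume` behaves like this; the function name `block_average` is hypothetical and uses plain NumPy rather than the library.

```python
import numpy as np


def block_average(vol, rate):
    """Downsample a 3D array by averaging rate x rate x rate blocks.

    Hypothetical stand-in for a volume downsampler; trailing samples that
    do not fill a complete block are discarded.
    """
    nz, ny, nx = (s // rate for s in vol.shape)
    vol = vol[: nz * rate, : ny * rate, : nx * rate]
    # Split each axis into (blocks, samples per block) and average the samples.
    return vol.reshape(nz, rate, ny, rate, nx, rate).mean(axis=(1, 3, 5))


hires = np.arange(64, dtype=float).reshape(4, 4, 4)
lowres = block_average(hires, 2)  # shape (2, 2, 2)
```

Each output voxel is the mean of the eight high-resolution voxels it covers, so voxels on a sphere boundary take intermediate values.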

Set up the forward operator and a test signal consisting of the blurred volume with additive Gaussian noise.

[3]:
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n, n)) / (n**3)
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred volume
noise, key = random.randn(Ax.shape)
y = Ax + σ * noise
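
The measurement model above is a box blur followed by additive Gaussian noise. The following NumPy-only sketch reproduces that model on a small random volume. It assumes a "full" convolution, in which each axis of the output grows by `n - 1` samples; the cropping applied to `y` before plotting later in this example is consistent with that assumption. The helper name `box_blur_full` is hypothetical.

```python
import numpy as np


def box_blur_full(vol, n):
    """Blur a 3D array with an n x n x n box kernel, keeping the full output."""
    kernel = np.ones(n) / n  # the 3D box kernel separates into three 1D boxes
    out = vol
    for axis in range(3):
        out = np.apply_along_axis(np.convolve, axis, out, kernel, mode="full")
    return out


rng = np.random.default_rng(0)
vol = rng.random((8, 10, 12))  # small stand-in for the ground truth volume
n = 5  # kernel size
sigma = 20.0 / 255  # noise level

blurred = box_blur_full(vol, n)  # shape grows by n - 1 along each axis
noisy = blurred + sigma * rng.standard_normal(blurred.shape)

nc = n // 2
cropped = noisy[nc:-nc, nc:-nc, nc:-nc]  # back to the shape of vol
```

Cropping `n // 2` samples from each end of every axis aligns the measurement with the original volume, which is what makes a voxel-wise comparison such as PSNR meaningful.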

Set up the ADMM solver.

[4]:
f = loss.SquaredL2Loss(y=y, A=A)
C = linop.Identity(x_gt.shape)

λ = 40.0 / 255  # BM4D regularization strength
g = λ * functional.BM4D()

ρ = 1.0  # ADMM penalty parameter
maxiter = 10  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=A.T @ y,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)
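
To make the structure of the solver concrete, here is a minimal 1D sketch of a PPP-style ADMM loop in plain NumPy. It is not the library's implementation, and it rests on several assumptions: the blur is circular, so the linear x-update can be solved exactly in the Fourier domain instead of by conjugate gradient; a 3-tap moving average stands in for BM4D; and the initialization (`z` equal to the initial `x`, `u` equal to zero) is chosen for illustration. All function names are hypothetical.

```python
import numpy as np


def blur(x, H):
    """Circular convolution via the FFT; H is the kernel's frequency response."""
    return np.real(np.fft.ifft(H * np.fft.fft(x)))


def x_update(y, z, u, H, rho):
    """Solve (A^T A + rho I) x = A^T y + rho (z - u) in the Fourier domain."""
    rhs = np.conj(H) * np.fft.fft(y) + rho * np.fft.fft(z - u)
    return np.real(np.fft.ifft(rhs / (np.abs(H) ** 2 + rho)))


def toy_denoiser(v):
    """Stand-in for BM4D: a 3-tap circular moving average."""
    return (np.roll(v, 1) + v + np.roll(v, -1)) / 3.0


def ppp_admm(y, H, rho=1.0, maxiter=10):
    """Run a fixed number of PPP-style ADMM iterations."""
    x = blur(y, np.conj(H))  # initial estimate A^T y
    z = x.copy()
    u = np.zeros_like(x)
    for _ in range(maxiter):
        x = x_update(y, z, u, H, rho)  # data fidelity step
        z = toy_denoiser(x + u)  # denoiser replaces the proximal step
        u = u + x - z  # scaled dual variable update
    return x


rng = np.random.default_rng(0)
x_true = np.zeros(64)
x_true[20:40] = 1.0
h = np.zeros(64)
h[:5] = 1.0 / 5  # box kernel, zero-padded to the signal length
H = np.fft.fft(h)
y = blur(x_true, H) + 0.05 * rng.standard_normal(64)
x_rec = ppp_admm(y, H, rho=1.0, maxiter=10)
```

In the 3D example the same three steps are carried out by the solver, with the x-update handled by conjugate gradient because the operator there is not diagonalized by the FFT.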

Run the solver.

[5]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  3.16e+01  7.594e+00  2.584e+01      3  3.616e-04
   1  6.01e+01  3.864e+00  1.773e+01      3  3.775e-04
   2  8.68e+01  2.389e+00  1.264e+01      3  2.436e-04
   3  1.10e+02  1.826e+00  9.431e+00      2  9.220e-04
   4  1.34e+02  1.634e+00  7.319e+00      2  6.801e-04
   5  1.59e+02  1.596e+00  5.925e+00      2  5.132e-04
   6  1.81e+02  1.507e+00  4.923e+00      2  3.967e-04
   7  2.01e+02  1.442e+00  4.189e+00      2  3.075e-04
   8  2.20e+02  1.330e+00  3.614e+00      2  2.544e-04
   9  2.39e+02  1.303e+00  3.201e+00      2  2.171e-04

Show slices of the recovered 3D volume.

[6]:
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(tile_volume_slices(x_gt), title="Ground truth", fig=fig, ax=ax[0])
nc = n // 2
yc = y[nc:-nc, nc:-nc, nc:-nc]
yc = snp.clip(yc, 0, 1)
plot.imview(
    tile_volume_slices(yc),
    title="Slices of blurred, noisy volume: %.2f (dB)" % metric.psnr(x_gt, yc),
    fig=fig,
    ax=ax[1],
)
plot.imview(
    tile_volume_slices(x),
    title="Slices of deconvolved volume: %.2f (dB)" % metric.psnr(x_gt, x),
    fig=fig,
    ax=ax[2],
)
fig.show()
[Figure: tiled slices of the ground truth, the blurred and noisy volume, and the deconvolved volume, with PSNR in dB in the panel titles.]
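
The panel titles report peak signal-to-noise ratio (PSNR) in dB. As a reminder of what that number measures, here is a NumPy sketch of the usual definition. Taking the signal range as the maximum minus the minimum of the reference when none is given is an assumption about the default convention, not something shown in this example.

```python
import numpy as np


def psnr(reference, comparison, signal_range=None):
    """Peak signal-to-noise ratio in dB (sketch of the usual definition)."""
    reference = np.asarray(reference, dtype=float)
    comparison = np.asarray(comparison, dtype=float)
    if signal_range is None:
        # Assumed default: dynamic range of the reference signal.
        signal_range = reference.max() - reference.min()
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


ref = np.array([0.0, 1.0, 0.0, 1.0])
value = psnr(ref, ref + 0.1)  # MSE is 0.01 and the range is 1
```

With a range of 1 and a mean squared error of 0.01, the ratio is 100, which corresponds to 20 dB. Higher values indicate a closer match to the reference.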

Plot convergence statistics.

[7]:
plot.plot(
    snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)
[Figure: semilog plot of the primal and dual residuals against iteration.]
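
For reference, with the identity splitting operator and $\rho = 1$ used in this example, the standard ADMM residual norms reduce to the expressions below. The symbols $x$ and $z$ denote the primal and auxiliary variables and are notation introduced here; how the library weights these quantities by $\rho$ in the general case is not shown in this example.

```latex
\begin{aligned}
\text{primal residual:} \quad & \big\| x^{(k)} - z^{(k)} \big\|_2 \\
\text{dual residual:} \quad & \big\| z^{(k)} - z^{(k-1)} \big\|_2
\end{aligned}
```

The primal residual measures how far the two copies of the solution are from agreeing, and the dual residual measures how much the denoised variable is still changing between iterations.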