PPP (with BM3D) Image Deconvolution (ADMM Solver)

This example demonstrates the solution of an image deconvolution problem using the Plug-and-Play Priors (PPP) algorithm [56], solved with ADMM, and with the BM3D [17] denoiser serving as the prior.
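In SCICO's formulation, the problem solved here is $\min_x f(x) + g(Cx)$, where $f(x) = \tfrac{1}{2}\|y - Ax\|_2^2$ (the default scaling of `loss.SquaredL2Loss`) and $C = I$. The scaled-form ADMM iterations are summarized below, with the proximal operator of $g$ replaced by a denoiser $D_\sigma$ as in PPP. This is a summary under stated assumptions, not a transcription of SCICO's code. In particular, the identification $\sigma = \lambda / \rho$ assumes that the BM3D functional uses its proximal step parameter as the denoiser noise level.

```latex
\begin{aligned}
x^{(k+1)} &= \arg\min_x \; \tfrac{1}{2}\|y - Ax\|_2^2 + \tfrac{\rho}{2}\|x - z^{(k)} + u^{(k)}\|_2^2
           = \left(A^T A + \rho I\right)^{-1}\left(A^T y + \rho\,(z^{(k)} - u^{(k)})\right) \\
z^{(k+1)} &= D_{\sigma}\left(x^{(k+1)} + u^{(k)}\right), \qquad \sigma = \lambda / \rho \\
u^{(k+1)} &= u^{(k)} + x^{(k+1)} - z^{(k+1)}
\end{aligned}
```

The $x$-update is a linear system, which `LinearSubproblemSolver` solves approximately by conjugate gradient (the "CG It" and "CG Res" columns of the iteration log). The primal and dual residuals are conventionally $\|x^{(k+1)} - z^{(k+1)}\|_2$ and $\rho\|z^{(k+1)} - z^{(k)}\|_2$. Under the assumption above, the choices $\lambda = 20/255$ and $\rho = 1$ give $\sigma = 20/255$, which equals the noise level of the test signal.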

[1]:
import numpy as np

import komplot as kplt
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, random
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
kplt.config_notebook_plotting()

Create a ground truth image.

[2]:
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array
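Where xdesign is not available, a qualitatively similar binary test image can be sketched with NumPy alone. The `disk_phantom` function below is hypothetical, not part of xdesign or SCICO. It draws randomly placed disks that may overlap, so unlike `Foam` it enforces no gap between them. Reusing the `size_range` numbers as disk radii, as fractions of the image width, is an assumption about how `Foam` interprets that parameter.

```python
import numpy as np


def disk_phantom(N, n_disks=200, r_range=(0.0025, 0.075), seed=1234):
    """Hypothetical NumPy-only stand-in for xdesign's Foam phantom.

    Returns an N x N float32 image equal to 1 inside any of `n_disks`
    randomly placed disks and 0 elsewhere. Disks may overlap and no gap is
    enforced. Radii are drawn from `r_range`, as fractions of the image
    width (assumed to mimic Foam's size_range).
    """
    rng = np.random.default_rng(seed)
    # Pixel-centre coordinates in [0, 1)
    c = (np.arange(N) + 0.5) / N
    xx, yy = np.meshgrid(c, c, indexing="ij")
    img = np.zeros((N, N), dtype=np.float32)
    for _ in range(n_disks):
        cx, cy = rng.uniform(0.0, 1.0, size=2)  # disk centre
        r = rng.uniform(*r_range)  # disk radius
        img[(xx - cx) ** 2 + (yy - cy) ** 2 <= r**2] = 1.0
    return img


x_gt_np = disk_phantom(512)
print(x_gt_np.shape, x_gt_np.dtype)
```

The result is a NumPy array. As in the cell above, it can be converted with `snp.array` before use with SCICO operators.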

Set up the forward operator (a uniform n × n blur) and a test signal consisting of the blurred image with additive Gaussian noise.

[3]:
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)
y = Ax + σ * noise
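The shape of the blurred image matters later, when `y` is cropped before computing its PSNR. This minimal NumPy sketch assumes that `linop.Convolve` uses its default "full" convolution mode. Under that assumption, an N × N image blurred by an n × n kernel yields an (N + n - 1) × (N + n - 1) result, and cropping `n // 2` pixels from each side restores the N × N shape. The helper `conv2d_full` is hypothetical and implements full linear convolution with zero-padded FFTs.

```python
import numpy as np


def conv2d_full(x, h):
    """Full 2-D linear convolution via zero-padded FFTs.

    Output shape is (x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1),
    which is large enough that the circular FFT convolution equals the
    linear one.
    """
    shape = (x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1)
    X = np.fft.rfft2(x, s=shape)
    H = np.fft.rfft2(h, s=shape)
    return np.fft.irfft2(X * H, s=shape)


N, n = 64, 5  # small image for illustration; kernel size as in the example
sigma = 20.0 / 255  # noise level as in the example
rng = np.random.default_rng(0)
x = (rng.uniform(size=(N, N)) > 0.5).astype(float)  # random binary image

psf = np.ones((n, n)) / (n * n)  # uniform blur kernel, sums to 1
Ax = conv2d_full(x, psf)  # blurred image, larger than x
y = Ax + sigma * rng.standard_normal(Ax.shape)  # additive Gaussian noise

nc = n // 2
y_crop = y[nc:-nc, nc:-nc]  # crop back to the shape of x
print(Ax.shape, y_crop.shape)  # (68, 68) (64, 64)
```

Because the kernel sums to 1, full convolution preserves the total intensity of the image. Each interior output pixel is the mean of the corresponding 5 × 5 neighborhood.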

Set up the ADMM solver, using the BM3D denoiser in place of the proximal operator of the regularization term.

[4]:
f = loss.SquaredL2Loss(y=y, A=A)
C = linop.Identity(x_gt.shape)

λ = 20.0 / 255  # BM3D regularization strength
g = λ * functional.BM3D()

ρ = 1.0  # ADMM penalty parameter
maxiter = 10  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=A.T @ y,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)
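The structure of the iteration configured above can be sketched in plain NumPy. This is a minimal sketch under stated assumptions, not SCICO's implementation. The hypothetical `box_denoise` (a moving average) stands in for BM3D. The blur is assumed to be *circular* rather than the "full" convolution of `linop.Convolve`, so the x-update can be solved exactly with FFTs instead of by CG. The residuals follow the standard definitions ||x - z|| and ρ||z - z_old||.

```python
import numpy as np


def box_denoise(v, k=3):
    """Hypothetical stand-in denoiser: k x k moving average with periodic
    boundaries. In the example above, BM3D plays this role."""
    out = np.zeros_like(v)
    r = k // 2
    for di in range(-r, r + 1):
        for dj in range(-r, r + 1):
            out += np.roll(v, (di, dj), axis=(0, 1))
    return out / (k * k)


def ppp_admm_deconv(y, psf, denoise, rho=1.0, maxiter=10):
    """PPP-ADMM sketch for min_x 0.5 ||y - h * x||^2 + g(x) with C = I.

    The prox of g is replaced by `denoise`, and the blur is assumed circular,
    so (A^T A + rho I) is diagonalized by the 2-D DFT.
    """
    H = np.fft.fft2(psf, s=y.shape)  # transfer function of the circular blur
    Y = np.fft.fft2(y)
    x = np.real(np.fft.ifft2(np.conj(H) * Y))  # x0 = A^T y
    z = x.copy()
    u = np.zeros_like(x)
    prml, dual = [], []
    for _ in range(maxiter):
        # x-update: solve (A^T A + rho I) x = A^T y + rho (z - u) exactly
        rhs = np.conj(H) * Y + rho * np.fft.fft2(z - u)
        x = np.real(np.fft.ifft2(rhs / (np.abs(H) ** 2 + rho)))
        z_old = z
        # z-update: denoiser in place of prox_{g/rho}(x + u)
        z = denoise(x + u)
        # scaled dual variable update
        u = u + x - z
        prml.append(np.linalg.norm(x - z))
        dual.append(rho * np.linalg.norm(z_old - z))
    return x, np.array(prml), np.array(dual)


# Usage on a small synthetic problem (circular blur, same noise level)
rng = np.random.default_rng(0)
x_gt = np.zeros((64, 64))
x_gt[16:48, 16:48] = 1.0
psf = np.ones((5, 5)) / 25
H = np.fft.fft2(psf, s=x_gt.shape)
y = np.real(np.fft.ifft2(H * np.fft.fft2(x_gt))) + (20 / 255) * rng.standard_normal(x_gt.shape)
x, prml, dual = ppp_admm_deconv(y, psf, box_denoise, rho=1.0, maxiter=10)
```

Swapping `box_denoise` for any other image denoiser leaves the loop unchanged, which is the central idea of PPP.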

Run the solver.

[5]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  4.31e+00  9.566e+00  1.463e+01      3  2.077e-04
   1  7.06e+00  3.677e+00  9.177e+00      3  2.275e-04
   2  1.07e+01  1.225e+00  6.477e+00      2  6.779e-04
   3  1.48e+01  8.981e-01  4.750e+00      2  4.347e-04
   4  2.02e+01  7.582e-01  3.669e+00      2  3.154e-04
   5  2.46e+01  6.767e-01  2.965e+00      2  2.339e-04
   6  2.88e+01  6.322e-01  2.480e+00      2  1.786e-04
   7  3.56e+01  5.247e-01  2.064e+00      1  8.090e-04
   8  4.28e+01  5.394e-01  1.825e+00      1  5.616e-04
   9  4.98e+01  5.240e-01  1.632e+00      1  5.274e-04

Show the recovered image.

[6]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(15, 5))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0])
nc = n // 2  # half-width of the blur kernel
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)  # crop y to the shape of x_gt for the PSNR comparison
kplt.imview(
    y, cmap="Blues", title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), ax=ax[1]
)
kplt.imview(x, cmap="Blues", title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), ax=ax[2])
fig.show()
[Figure: ground truth, blurred noisy image, and deconvolved image; the panel titles give the PSNR in dB]
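The PSNR values in the panel titles compare each image with the ground truth, which is why the noisy image is cropped to the shape of `x_gt` first. A minimal NumPy sketch of the metric follows. Taking the reference's dynamic range (max - min) as the default signal range is an assumption modelled on, not copied from, `scico.metric.psnr`.

```python
import numpy as np


def psnr(reference, comparison, signal_range=None):
    """Peak signal-to-noise ratio in dB.

    If `signal_range` is None, the dynamic range (max - min) of `reference`
    is used (an assumed default).
    """
    if signal_range is None:
        signal_range = reference.max() - reference.min()
    mse = np.mean((reference - comparison) ** 2)
    return 10 * np.log10(signal_range**2 / mse)


x_ref = np.zeros((4, 4))
x_ref[:2] = 1.0  # binary reference with range 1
print(round(psnr(x_ref, x_ref + 0.1), 6))  # 20.0
```

A uniform error of 0.1 on a signal of range 1 gives an MSE of 0.01 and therefore 20 dB. Each halving of the RMS error adds about 6 dB.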

Plot convergence statistics.

[7]:
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
)
[Figure: primal and dual residuals versus iteration, log scale]