PPP (with BM3D) Image Deconvolution (APGM Solver)

This example demonstrates the solution of an image deconvolution problem using the APGM (accelerated proximal gradient method) variant of the Plug-and-Play Priors (PPP) algorithm [37], with the BM3D [17] denoiser.

[1]:
import numpy as np

import komplot as kplt
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, random
from scico.optimize.pgm import AcceleratedPGM
from scico.util import device_info
kplt.config_notebook_plotting()

Create a ground truth image.

[2]:
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array

Set up the forward operator and a test signal consisting of the blurred image with additive Gaussian noise.

[3]:
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)
y = Ax + σ * noise
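
The forward model in this cell can be illustrated with a minimal NumPy-only sketch. It assumes that the convolution is computed in "full" mode, so that the blurred image is larger than the input by `n - 1` pixels in each dimension (the cropping by `n // 2` in the display cell is consistent with this). The helper `box_blur_full` and the small 8×8 random test image are hypothetical and are not part of the original example.

```python
import numpy as np


def box_blur_full(x, n):
    """Full 2D convolution of x with an n x n averaging kernel (hypothetical helper)."""
    H, W = x.shape
    out = np.zeros((H + n - 1, W + n - 1))
    # Each kernel tap contributes a shifted, scaled copy of the input.
    for i in range(n):
        for j in range(n):
            out[i : i + H, j : j + W] += x / (n * n)
    return out


rng = np.random.default_rng(0)
x_small = rng.random((8, 8))  # small stand-in for the ground truth image
n = 5  # convolution kernel size, as in the example
sigma = 20.0 / 255  # noise level, as in the example

Ax_small = box_blur_full(x_small, n)  # blurred image, shape (12, 12)
y_small = Ax_small + sigma * rng.standard_normal(Ax_small.shape)

nc = n // 2
y_crop = y_small[nc:-nc, nc:-nc]  # cropped back to the input shape
```

Because the kernel entries sum to one, the blur preserves the total intensity of the image, and cropping `n // 2` pixels from each side recovers an array with the shape of the input.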

Set up the APGM solver.

[4]:
f = loss.SquaredL2Loss(y=y, A=A)

L0 = 15  # APGM inverse step size parameter
λ = L0 * 2.0 / 255  # BM3D regularization strength
g = λ * functional.BM3D()

maxiter = 50  # number of APGM iterations

solver = AcceleratedPGM(
    f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={"display": True, "period": 10}
)
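
The structure of the iteration configured above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the library's implementation: it uses a fixed step size `1 / L`, the standard FISTA momentum update, and a soft-thresholding function as a stand-in for the BM3D denoiser. The names `pnp_apgm`, `soft_threshold`, and `objective`, and the 1D toy problem, are hypothetical. The sketch also assumes that the denoiser strength passed at each step is `lam / L`; if the solver scales the proximal step the same way, the choice `λ = L0 * 2.0 / 255` in the cell above would correspond to an effective denoiser noise level of `2.0 / 255`.

```python
import numpy as np


def soft_threshold(v, t):
    """Stand-in denoiser (NOT BM3D): shrink each entry toward zero by t."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def pnp_apgm(A, y, denoise, L, lam, maxiter=50):
    """Accelerated proximal gradient loop with a pluggable denoiser."""
    x = np.zeros(A.shape[1])
    v = x.copy()  # extrapolated point
    t = 1.0  # momentum parameter
    for _ in range(maxiter):
        grad = A.T @ (A @ v - y)  # gradient of 0.5 * ||A v - y||^2
        x_new = denoise(v - grad / L, lam / L)  # denoiser replaces the prox
        t_new = 0.5 * (1.0 + np.sqrt(1.0 + 4.0 * t * t))
        v = x_new + ((t - 1.0) / t_new) * (x_new - x)
        x, t = x_new, t_new
    return x


# Toy 1D deblurring problem: length-5 averaging blur of a sparse signal.
N, n = 64, 5
A = np.zeros((N, N))
for i in range(N):
    for j in range(max(0, i - n // 2), min(N, i + n // 2 + 1)):
        A[i, j] = 1.0 / n
x_true = np.zeros(N)
x_true[[10, 30, 50]] = 1.0
rng = np.random.default_rng(0)
y = A @ x_true + 0.01 * rng.standard_normal(N)


def objective(x, lam):
    """Data fidelity plus the l1 penalty matched to soft thresholding."""
    return 0.5 * np.sum((A @ x - y) ** 2) + lam * np.sum(np.abs(x))


lam = 0.01
# L = 1 is a valid bound here because the averaging matrix has norm at most 1.
x_hat = pnp_apgm(A, y, soft_threshold, L=1.0, lam=lam)
obj_init = objective(np.zeros(N), lam)
obj_final = objective(x_hat, lam)
```

With soft thresholding as the denoiser, the loop reduces to standard FISTA for an l1-regularized least squares problem, so the objective has a well-defined meaning. With a general denoiser such as BM3D there is usually no explicit regularization functional, which is why the example monitors a residual rather than an objective value.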

Run the solver.

[5]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      L          Residual
------------------------------------
   0  5.27e+00  1.500e+01  2.068e+00
  10  5.97e+01  1.500e+01  6.043e-01
  20  1.09e+02  1.500e+01  2.108e-01
  30  1.43e+02  1.500e+01  1.500e-01
  40  1.70e+02  1.500e+01  1.425e-01
  49  2.09e+02  1.500e+01  1.363e-01

Show the recovered image.

[6]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(15, 5))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0])
nc = n // 2  # half-width of the blur kernel
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)  # crop y to the shape of x_gt and clip to [0, 1] for the PSNR
kplt.imview(
    y, cmap="Blues", title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), ax=ax[1]
)
kplt.imview(x, cmap="Blues", title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), ax=ax[2])
fig.show()
[Figure: ground truth, blurred and noisy image, and deconvolved image; the panel titles report PSNR in dB]
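
The PSNR values in the panel titles can be understood from the following minimal sketch. It assumes a fixed signal range of 1.0, which suits images with values in [0, 1]; the default range convention used by `metric.psnr` may differ, so this is an illustration of the formula rather than a reimplementation of the library function. The helper name `psnr_sketch` is hypothetical.

```python
import numpy as np


def psnr_sketch(reference, comparison, signal_range=1.0):
    """Peak signal-to-noise ratio in dB for a given signal range (assumed 1.0)."""
    mse = np.mean((np.asarray(reference) - np.asarray(comparison)) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


ref = np.array([0.0, 1.0])
cmp = np.array([0.1, 0.9])
value = psnr_sketch(ref, cmp)  # mean squared error is 0.01, so about 20 dB
```

A higher PSNR indicates a smaller mean squared error relative to the signal range, so an increase from the blurred, noisy image to the deconvolved image reflects a reconstruction that is closer to the ground truth.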

Plot convergence statistics.

[7]:
kplt.plot(hist.Residual, ylog=True, title="PGM Residual", xlabel="Iteration", ylabel="Residual")
[Figure: PGM residual versus iteration, with a logarithmic vertical axis]