PPP (with BM3D) Image Deconvolution (APGM Solver)
This example demonstrates the solution of an image deconvolution problem using the Plug-and-Play Priors (PPP) algorithm [37], with the accelerated proximal gradient method (APGM) as the solver and the BM3D [17] denoiser as the prior.
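The PPP variant of APGM replaces the proximal operator of the regularizer with a denoiser. The sketch below shows that iteration in NumPy. It is a generic FISTA-style loop, not SCICO's `AcceleratedPGM` source. The function names `pnp_apgm` and `soft_threshold` are hypothetical. Soft thresholding stands in for BM3D so the sketch runs without the `bm3d` package. The denoiser is assumed to receive the noise level λ/L, which is what a prox step with parameter 1/L applied to λ·g amounts to.

```python
import numpy as np


def pnp_apgm(grad_f, denoise, L, lam, x0, maxiter=50):
    """FISTA-style accelerated proximal gradient with a plug-in denoiser.

    grad_f : callable returning the gradient of the data-fidelity term f
    denoise: callable denoise(v, sigma) used in place of prox_{(1/L) lam g}
    L      : inverse step size (should be >= Lipschitz constant of grad_f)
    lam    : regularization strength; the denoiser is called with sigma = lam / L
    """
    x = x0.copy()
    z = x0.copy()  # extrapolated (momentum) point
    t = 1.0
    for _ in range(maxiter):
        # Gradient step on f, then the "prox" step carried out by the denoiser
        x_new = denoise(z - grad_f(z) / L, lam / L)
        # Nesterov momentum update
        t_new = (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        z = x_new + ((t - 1.0) / t_new) * (x_new - x)
        x, t = x_new, t_new
    return x


def soft_threshold(v, sigma):
    # Stand-in "denoiser": the exact prox of sigma * ||.||_1
    return np.sign(v) * np.maximum(np.abs(v) - sigma, 0.0)


# Toy problem with A = I: f(x) = 0.5 * ||x - y||^2, gradient x - y, Lipschitz constant 1
rng = np.random.default_rng(0)
y = rng.standard_normal(100)
x_hat = pnp_apgm(lambda x: x - y, soft_threshold, L=1.0, lam=0.3, x0=np.zeros_like(y))
```

With A = I and the exact ℓ1 prox as the "denoiser", the loop reaches the closed-form answer `soft_threshold(y, 0.3)` after the first iteration. This makes it an easy sanity check before a real denoiser such as BM3D is substituted.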
[1]:
import numpy as np
import komplot as kplt
from xdesign import Foam, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric, random
from scico.optimize.pgm import AcceleratedPGM
from scico.util import device_info
kplt.config_notebook_plotting()
Create a ground truth image.
[2]:
np.random.seed(1234)
N = 512 # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt) # convert to jax array
Set up the forward operator and a test signal consisting of the blurred ground truth with additive Gaussian noise.
[3]:
n = 5 # convolution kernel size
σ = 20.0 / 255 # noise level
psf = snp.ones((n, n)) / (n * n)
A = linop.Convolve(h=psf, input_shape=x_gt.shape)
Ax = A(x_gt) # blurred image
noise, key = random.randn(Ax.shape)
y = Ax + σ * noise
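The same measurement model can be sketched with NumPy and SciPy. This assumes that `linop.Convolve` performs "full" convolution, as the later cropping by `n // 2` suggests, so that `y` is `n - 1` pixels larger than `x_gt` in each dimension. In the sketch, `scipy.signal.convolve2d` replaces the SCICO operator and a small random image replaces the foam phantom.

```python
import numpy as np
from scipy.signal import convolve2d

N, n = 64, 5             # small image for illustration; kernel size as in the example
sigma = 20.0 / 255       # noise level as in the example
rng = np.random.default_rng(1234)

x = rng.uniform(size=(N, N))           # stand-in for the ground truth image
psf = np.ones((n, n)) / (n * n)        # normalized box blur (sums to 1)

Ax = convolve2d(x, psf, mode="full")   # blurred image, shape (N + n - 1, N + n - 1)
y = Ax + sigma * rng.standard_normal(Ax.shape)

# Cropping n // 2 pixels from each side recovers the ground-truth shape
nc = n // 2
y_crop = y[nc:-nc, nc:-nc]
```

Because the box PSF sums to 1, the blur leaves a constant image unchanged wherever the kernel fully overlaps the image. Only the `n - 1` pixel border is attenuated.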
Set up the APGM solver.
[4]:
f = loss.SquaredL2Loss(y=y, A=A)
L0 = 15 # APGM inverse step size parameter
λ = L0 * 2.0 / 255 # BM3D regularization strength
g = λ * functional.BM3D()
maxiter = 50 # number of APGM iterations
solver = AcceleratedPGM(
f=f, g=g, L0=L0, x0=A.T @ y, maxiter=maxiter, itstat_options={"display": True, "period": 10}
)
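The choices `L0 = 15` and `λ = L0 * 2.0 / 255` can be checked numerically under two assumptions. First, `SquaredL2Loss` is assumed to use its default scaling f(x) = (1/2)||y − Ax||². Second, the APGM prox step on λ·BM3D with parameter 1/L0 is assumed to pass λ/L0 to BM3D as its noise level. Under these assumptions, the Lipschitz constant of ∇f is at most ||h||₁² = 1. That makes `L0 = 15` a conservative (small) step size, and the denoiser sees σ = 2/255 whatever value `L0` takes.

```python
import numpy as np

N, n = 512, 5
psf = np.ones((n, n)) / (n * n)

# Frequency response of the PSF, zero-padded to the image size
H = np.fft.fft2(psf, s=(N, N))
# For circular convolution ||A||_2 = max|H|. For linear ("full") convolution,
# Young's inequality gives ||A||_2 <= sum|h|, which equals max|H| (the DC value)
# for a nonnegative PSF.
lipschitz_bound = np.max(np.abs(H)) ** 2  # bound on the Lipschitz constant of grad f

L0 = 15
lam = L0 * 2.0 / 255
# Assumption: a prox step with parameter 1/L0 on lam * BM3D hands BM3D the level lam / L0
sigma_denoiser = lam / L0
```

Because λ is defined as a multiple of `L0`, changing `L0` changes the gradient step size but not the denoising strength.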
Run the solver.
[5]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time L Residual
------------------------------------
0 5.27e+00 1.500e+01 2.068e+00
10 5.97e+01 1.500e+01 6.043e-01
20 1.09e+02 1.500e+01 2.108e-01
30 1.43e+02 1.500e+01 1.500e-01
40 1.70e+02 1.500e+01 1.425e-01
49 2.09e+02 1.500e+01 1.363e-01
Show the recovered image.
[6]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(15, 5))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0])
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
kplt.imview(
y, cmap="Blues", title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), ax=ax[1]
)
kplt.imview(x, cmap="Blues", title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), ax=ax[2])
fig.show()
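The PSNR values in the panel titles compare each image against `x_gt`. For that reason, `y` is first cropped by `n // 2` on each side to undo the growth from full convolution, then clipped to [0, 1]. The uncropped `y` is what is displayed. Below is a minimal NumPy sketch of the metric. It assumes `metric.psnr` uses the standard definition 10·log10(R²/MSE), with the range R defaulting to the dynamic range of the reference. That default should be checked against the SCICO documentation. The `psnr` function here is a hypothetical re-implementation.

```python
import numpy as np


def psnr(reference, comparison, signal_range=None):
    """Peak signal-to-noise ratio in dB (sketch of the standard definition)."""
    if signal_range is None:
        # Assumption: default range is the dynamic range of the reference image
        signal_range = np.max(reference) - np.min(reference)
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


# Worked example: reference spans [0, 1] and a uniform error of 0.1 gives MSE = 0.01,
# so PSNR = 10 * log10(1 / 0.01), approximately 20 dB.
ref = np.zeros((8, 8))
ref[0, 0] = 1.0
value = psnr(ref, ref + 0.1)
```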
Plot convergence statistics.
[7]:
kplt.plot(hist.Residual, ylog=True, title="PGM Residual", xlabel="Iteration", ylabel="Residual")