PPP (with DnCNN) Image Deconvolution (Proximal ADMM Solver)#

This example demonstrates the solution of an image deconvolution problem using a proximal ADMM variant of the Plug-and-Play Priors (PPP) algorithm [50] with the DnCNN [58] denoiser.

[1]:
import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot, random
from scico.optimize import ProximalADMM
from scico.util import device_info
plot.config_notebook_plotting()

Create a ground truth image.

[2]:
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array

Set up forward operator \(A\) and test signal consisting of blurred signal with additive Gaussian noise.

[3]:
n = 5  # convolution kernel size
σ = 20.0 / 255  # noise level

psf = snp.ones((n, n)) / (n * n)
A = linop.Convolve(h=psf, input_shape=x_gt.shape)

Ax = A(x_gt)  # blurred image
noise, key = random.randn(Ax.shape)
y = Ax + σ * noise

Set up the problem to be solved. We want to minimize the functional

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + R(\mathbf{x}) \;\]
where \(R(\cdot)\) is a pseudo-functional having the DnCNN denoiser as its proximal operator. A slightly unusual variable splitting is used,
including setting the \(f\) functional to the \(R(\cdot)\) term and the \(g\) functional to the data fidelity term to allow the use of proximal ADMM, which avoids the need for conjugate gradient sub-iterations in the solver steps.
[4]:
f = functional.DnCNN(variant="17M")
g = loss.SquaredL2Loss(y=y)

Set up proximal ADMM solver.

[5]:
ρ = 0.2  # ADMM penalty parameter
maxiter = 10  # number of proximal ADMM iterations
mu, nu = ProximalADMM.estimate_parameters(A)

solver = ProximalADMM(
    f=f,
    g=g,
    A=A,
    rho=ρ,
    mu=mu,
    nu=nu,
    x0=A.T @ y,
    maxiter=maxiter,
    itstat_options={"display": True},
)

Run the solver.

[6]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
x = snp.clip(x, 0, 1)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl
------------------------------------
   0  1.87e+00  4.320e+01  2.631e+02
   1  2.09e+00  2.946e+01  5.875e+00
   2  2.29e+00  2.433e+01  4.868e+00
   3  2.49e+00  2.039e+01  4.079e+00
   4  2.69e+00  1.701e+01  3.403e+00
   5  2.90e+00  1.412e+01  2.825e+00
   6  3.10e+00  1.172e+01  2.345e+00
   7  3.30e+00  9.744e+00  1.950e+00
   8  3.50e+00  8.121e+00  1.625e+00
   9  3.70e+00  6.782e+00  1.357e+00

Show the recovered image.

[7]:
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
nc = n // 2
yc = snp.clip(y[nc:-nc, nc:-nc], 0, 1)
plot.imview(y, title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), fig=fig, ax=ax[1])
plot.imview(x, title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, x), fig=fig, ax=ax[2])
fig.show()
../_images/examples_deconv_ppp_dncnn_padmm_13_0.png

Plot convergence statistics.

[8]:
plot.plot(
    snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)
../_images/examples_deconv_ppp_dncnn_padmm_15_0.png