Image Deconvolution with TV Regularization (ADMM Solver)
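This example demonstrates the solution of an image deconvolution problem with isotropic total variation (TV) regularization
\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]
where \(C\) is a convolution operator, \(\mathbf{y}\) is the blurred, noisy image, \(D\) is a 2D finite difference operator, \(\lambda\) is the regularization parameter, and \(\mathbf{x}\) is the deconvolved image.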
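For reference, the isotropic TV term can be written out per pixel. Writing \(D_r\) and \(D_c\) (notation introduced here, not taken from the original) for the row-direction and column-direction components of \(D\), the \(\ell_{2,1}\) norm takes the \(\ell_2\) norm of the two differences at each pixel \(i\) and then sums over pixels, whereas the anisotropic alternative sums absolute values:

```latex
\| D \mathbf{x} \|_{2,1} = \sum_i \sqrt{(D_r \mathbf{x})_i^2 + (D_c \mathbf{x})_i^2}
\qquad \text{vs.} \qquad
\| D \mathbf{x} \|_1 = \sum_i \left( |(D_r \mathbf{x})_i| + |(D_c \mathbf{x})_i| \right)
```

Because the isotropic form couples the two differences at each pixel, it treats edges of different orientations more evenly than the anisotropic form does.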
In this example the problem is solved via standard ADMM, while proximal ADMM is used in a companion example.
import komplot as kplt
from xdesign import SiemensStar, discrete_phantom
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
kplt.config_notebook_plotting()
Create a ground truth image.
phantom = SiemensStar(32)
N = 256 # image size
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
Set up the forward operator and create a test signal consisting of a blurred signal with additive Gaussian noise.
n = 5 # convolution kernel size
σ = 20.0 / 255 # noise level
psf = snp.ones((n, n)) / (n * n)
C = linop.Convolve(h=psf, input_shape=x_gt.shape)
Cx = C(x_gt) # blurred image
noise, key = scico.random.randn(Cx.shape, seed=0)
y = Cx + σ * noise
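The forward model can be mimicked in plain NumPy to make the array shapes explicit. This is a sketch, not SCICO code: `conv_full` and `conv_full_adj` are hypothetical helpers, the phantom is a simple square rather than a Siemens star, and the "full" output size is an assumption inferred from the later cropping of `y` by `n // 2` pixels on each side. The adjoint is the kind of operation used to form the initial guess `x0=C.adj(y)`.

```python
import numpy as np

def conv_full(x, h):
    """Full 2D convolution of x with kernel h (zero boundary), shape N+n-1 per axis."""
    n0, n1 = h.shape
    out = np.zeros((x.shape[0] + n0 - 1, x.shape[1] + n1 - 1))
    for i in range(n0):
        for j in range(n1):
            # Each kernel tap adds a shifted, weighted copy of x.
            out[i:i + x.shape[0], j:j + x.shape[1]] += h[i, j] * x
    return out

def conv_full_adj(y, h):
    """Adjoint of conv_full: correlate y with h, returning the input-sized array."""
    n0, n1 = h.shape
    m0, m1 = y.shape[0] - n0 + 1, y.shape[1] - n1 + 1
    out = np.zeros((m0, m1))
    for i in range(n0):
        for j in range(n1):
            out += h[i, j] * y[i:i + m0, j:j + m1]
    return out

rng = np.random.default_rng(0)
N, n = 64, 5
x_gt = np.zeros((N, N))
x_gt[16:48, 16:48] = 1.0               # hypothetical square test image
psf = np.ones((n, n)) / (n * n)        # uniform n x n blur kernel
sigma = 20.0 / 255                     # noise level, as in the notebook

Cx = conv_full(x_gt, psf)              # blurred image
y = Cx + sigma * rng.standard_normal(Cx.shape)
print(Cx.shape)                        # (68, 68)

# Dot-product test: <C a, b> should equal <a, C^T b> for random a, b.
a = rng.standard_normal((N, N))
b = rng.standard_normal(Cx.shape)
print(np.isclose(np.vdot(conv_full(a, psf), b), np.vdot(a, conv_full_adj(b, psf))))  # True
```

Since the kernel sums to one, the full convolution preserves the total intensity of the image; only its spatial extent grows by `n - 1` pixels per axis.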
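Set up the problem to be solved. We want to solve
\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]
where \(C\) is the convolution operator and \(D\) is a finite difference operator. This problem can be expressed as
\[\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| \mathbf{z} \|_{2,1} \;\; \text{such that} \;\; \mathbf{z} = D \mathbf{x} \;,\]
which is easily written in the form of a standard ADMM problem.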
This is a simpler splitting than the one used in the companion example, but it requires conjugate gradient (CG) sub-iterations to solve the ADMM step associated with the data fidelity term.
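For orientation, the standard scaled-form ADMM iterations for this splitting are shown below, with \(\mathbf{u}\) the scaled dual variable and \(\rho\) the penalty parameter; SCICO's `ADMM` class may differ in details such as initialization. Setting the gradient of the \(\mathbf{x}\)-step objective to zero gives the linear system on the second line, which is the step handled by the CG sub-iterations:

```latex
\begin{aligned}
\mathbf{x}^{(k+1)} &= \mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2
  + (\rho/2) \| D \mathbf{x} - \mathbf{z}^{(k)} + \mathbf{u}^{(k)} \|_2^2 \\
  &\text{i.e.} \;\; (C^T C + \rho D^T D) \, \mathbf{x}^{(k+1)}
  = C^T \mathbf{y} + \rho D^T (\mathbf{z}^{(k)} - \mathbf{u}^{(k)}) \\
\mathbf{z}^{(k+1)} &= \mathrm{prox}_{(\lambda/\rho) \| \cdot \|_{2,1}} \big( D \mathbf{x}^{(k+1)} + \mathbf{u}^{(k)} \big) \\
\mathbf{u}^{(k+1)} &= \mathbf{u}^{(k)} + D \mathbf{x}^{(k+1)} - \mathbf{z}^{(k+1)}
\end{aligned}
```

The matrix \(C^T C + \rho D^T D\) is symmetric, and positive definite as long as \(C\) and \(D\) share no nonzero null-space vector, which is what allows CG to be used for the \(\mathbf{x}\)-step.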
f = loss.SquaredL2Loss(y=y, A=C)
# Penalty parameters must be accounted for in the g_i functions, not as
# additional inputs.
λ = 2.1e-2 # ℓ2,1 norm regularization parameter
g = λ * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm;
# the L21Norm is used so that g(Dx) corresponds to isotropic TV.
D = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
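To see why zero-appended differences and the ℓ2,1 norm together give isotropic TV, here is a NumPy sketch. `finite_difference`, `l21_norm`, and `prox_l21` are hypothetical helpers; `np.diff(..., append=0)` is used to mimic the shape-preserving behavior described in the code comment, and SCICO's own sign and axis conventions may differ. The proximal operator (block soft-thresholding) is the operation an ADMM \(\mathbf{z}\)-step applies for \(g = \lambda \|\cdot\|_{2,1}\).

```python
import numpy as np

def finite_difference(x):
    """Stack row- and column-direction forward differences, shape (2, *x.shape).

    Appending a zero before differencing keeps each output the same shape
    as x, so the two directions can be stacked along a new leading axis.
    """
    return np.stack([np.diff(x, axis=0, append=0), np.diff(x, axis=1, append=0)])

def l21_norm(v):
    """l2 norm across the leading (direction) axis, summed over pixels."""
    return np.sum(np.sqrt(np.sum(v**2, axis=0)))

def prox_l21(v, t):
    """Prox of t * l21_norm: shrink each pixel's difference vector by t."""
    mag = np.sqrt(np.sum(v**2, axis=0))
    scale = np.maximum(1.0 - t / np.maximum(mag, 1e-12), 0.0)  # avoid divide-by-zero
    return v * scale  # broadcasts over the direction axis

# A 4x4 image with a vertical step edge.
x = np.zeros((4, 4))
x[:, 2:] = 1.0
Dx = finite_difference(x)
print(Dx.shape)                  # (2, 4, 4)
print(round(l21_norm(Dx), 4))    # 9.4142  (isotropic TV, 8 + sqrt(2))
print(np.sum(np.abs(Dx)))        # 10.0    (anisotropic TV, for comparison)
```

The totals include boundary terms: appending zeros treats the image as zero beyond its last row and column, so the bright right-hand region also produces differences along the bottom and right borders. At the bottom-right corner both differences are nonzero, and the isotropic norm counts that pixel as \(\sqrt{2}\) rather than 2.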
Set up an ADMM solver object.
ρ = 1.0e-1 # ADMM penalty parameter
maxiter = 50 # number of ADMM iterations
solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[D],
    rho_list=[ρ],
    x0=C.adj(y),
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(),
    itstat_options={"display": True, "period": 10},
)
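The solver configured above can be approximated by a short NumPy loop, sketched here under several assumptions: it uses the standard scaled-form ADMM updates, initializes \(\mathbf{z} = D\mathbf{x}_0\) and \(\mathbf{u} = 0\), solves the \(\mathbf{x}\)-step with a hand-written CG warm-started from the previous iterate, and records the usual primal residual \(\|D\mathbf{x} - \mathbf{z}\|_2\) and dual residual \(\rho \|D^T(\mathbf{z} - \mathbf{z}_{\text{prev}})\|_2\). All helper names are hypothetical, and SCICO's `ADMM` and `LinearSubproblemSolver` may differ in initialization, stopping rules, and residual normalization.

```python
import numpy as np

def conv(x, h):
    """Full 2D convolution with zero boundary: output is (N+n-1) x (N+n-1)."""
    out = np.zeros((x.shape[0] + h.shape[0] - 1, x.shape[1] + h.shape[1] - 1))
    for i in range(h.shape[0]):
        for j in range(h.shape[1]):
            out[i:i + x.shape[0], j:j + x.shape[1]] += h[i, j] * x
    return out

def conv_adj(y, h):
    """Adjoint of conv: correlate with h and keep the input-sized region."""
    m0, m1 = y.shape[0] - h.shape[0] + 1, y.shape[1] - h.shape[1] + 1
    out = np.zeros((m0, m1))
    for i in range(h.shape[0]):
        for j in range(h.shape[1]):
            out += h[i, j] * y[i:i + m0, j:j + m1]
    return out

def D(x):
    """Zero-appended forward differences along rows and columns, shape (2, *x.shape)."""
    return np.stack([np.diff(x, axis=0, append=0), np.diff(x, axis=1, append=0)])

def D_adj(w):
    """Adjoint of D."""
    out = -w[0] - w[1]
    out[1:, :] += w[0][:-1, :]
    out[:, 1:] += w[1][:, :-1]
    return out

def prox_l21(v, t):
    """Block soft-thresholding: prox of t * (sum of per-pixel l2 norms)."""
    mag = np.sqrt(np.sum(v**2, axis=0))
    return v * np.maximum(1.0 - t / np.maximum(mag, 1e-12), 0.0)

def objective(x, y, h, lam):
    """(1/2)||y - Cx||^2 + lam * ||Dx||_{2,1}."""
    tv = np.sum(np.sqrt(np.sum(D(x) ** 2, axis=0)))
    return 0.5 * np.sum((y - conv(x, h)) ** 2) + lam * tv

def cg(A, b, x, tol=1e-5, maxiter=100):
    """Conjugate gradient for A x = b (A symmetric positive definite), warm start x."""
    r = b - A(x)
    p = r.copy()
    rs = np.vdot(r, r)
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol * np.linalg.norm(b):
            break
        Ap = A(p)
        alpha = rs / np.vdot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = np.vdot(r, r)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

def admm_tv_deconv(y, h, lam, rho, maxiter):
    Cty = conv_adj(y, h)
    x = Cty.copy()                                # x0 = C^T y, as in the notebook
    z = D(x)
    u = np.zeros_like(z)
    A = lambda v: conv_adj(conv(v, h), h) + rho * D_adj(D(v))
    hist = []
    for _ in range(maxiter):
        x = cg(A, Cty + rho * D_adj(z - u), x)    # x-step: CG sub-iterations
        Dx = D(x)
        z_prev = z
        z = prox_l21(Dx + u, lam / rho)           # z-step: prox of (lam/rho)*l21
        u = u + Dx - z                            # scaled dual update
        hist.append((objective(x, y, h, lam),
                     np.linalg.norm(Dx - z),                       # primal residual
                     rho * np.linalg.norm(D_adj(z - z_prev))))     # dual residual
    return x, hist

rng = np.random.default_rng(0)
N, n = 32, 5
x_gt = np.zeros((N, N))
x_gt[8:24, 8:24] = 1.0                            # hypothetical square phantom
h = np.ones((n, n)) / (n * n)
y = conv(x_gt, h) + (20.0 / 255) * rng.standard_normal((N + n - 1, N + n - 1))
x, hist = admm_tv_deconv(y, h, lam=2.1e-2, rho=1.0e-1, maxiter=50)
```

Warm-starting CG from the previous \(\mathbf{x}\) is one plausible reason the CG iteration counts in the log below shrink as ADMM progresses, since consecutive \(\mathbf{x}\)-step systems differ less and less.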
Run the solver.
print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.01e+00  2.237e+02  1.075e+01  1.256e+00      9  7.026e-05
  10  3.14e+00  3.320e+02  3.950e-01  2.751e-01      4  8.945e-05
  20  3.36e+00  3.328e+02  1.800e-01  1.132e-01      3  6.900e-05
  30  3.52e+00  3.331e+02  1.097e-01  6.405e-02      2  7.460e-05
  40  3.67e+00  3.333e+02  7.689e-02  3.823e-02      2  5.013e-05
  49  3.78e+00  3.334e+02  5.960e-02  2.826e-02      1  8.888e-05
Show the recovered image.
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(15, 5))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0])
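nc = n // 2
# y = C(x_gt) + noise is larger than x_gt (full convolution adds n - 1 pixels
# per axis), so crop n // 2 pixels from each side before computing the PSNR.
yc = y[nc:-nc, nc:-nc]
kplt.imview(
    y, cmap="Blues", title="Blurred, noisy image: %.2f (dB)" % metric.psnr(x_gt, yc), ax=ax[1]
)
kplt.imview(
    solver.x,
    cmap="Blues",
    title="Deconvolved image: %.2f (dB)" % metric.psnr(x_gt, solver.x),
    ax=ax[2],
)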
fig.show()
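The PSNR values in the panel titles compare each image with `x_gt`, which is why the blurred image is first cropped by `n // 2` pixels per side. A NumPy sketch of that crop and of one common PSNR convention follows; the helper `psnr` is hypothetical, and it assumes the peak signal is the dynamic range of the reference, which may not match how `scico.metric.psnr` chooses its peak.

```python
import numpy as np

def psnr(reference, comparison):
    """PSNR in dB, taking the reference's dynamic range as the peak (assumed convention)."""
    reference = np.asarray(reference, dtype=float)
    comparison = np.asarray(comparison, dtype=float)
    peak = reference.max() - reference.min()
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(peak**2 / mse)

n = 5                                 # odd kernel size, so N + n - 1 == N + 2 * (n // 2)
nc = n // 2
x_gt = np.zeros((8, 8))
x_gt[2:6, 2:6] = 1.0
# Hypothetical (N+n-1) x (N+n-1) array standing in for a degraded "full" output.
y_full = np.pad(x_gt, nc) + 0.1
yc = y_full[nc:-nc, nc:-nc]           # crop back to the ground-truth shape
print(yc.shape)                       # (8, 8)
print(round(psnr(x_gt, yc), 2))       # 20.0 (peak 1, MSE 0.01)
```

The crop only lines up exactly when `n` is odd; with `nc = 0` the slice `y[0:-0]` would be empty, so this pattern assumes `n >= 3`.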
Plot convergence statistics.
fig, ax = kplt.subplots(nrows=1, ncols=2, figsize=(12, 5))
kplt.plot(
    hist.Objective,
    title="Objective function",
    xlabel="Iteration",
    ylabel="Functional value",
    ax=ax[0],
)
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
    ax=ax[1],
)
fig.show()