ℓ1 Total Variation Denoising
This example demonstrates impulse noise removal via ℓ1 total variation [2] [21] (Sec. 2.4.4) (i.e. total variation regularization with an ℓ1 data fidelity term), minimizing the functional
\[\mathrm{argmin}_{\mathbf{x}} \; \| \mathbf{y} - \mathbf{x} \|_1 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]
where \(\mathbf{y}\) is the noisy image, \(C\) is a 2D finite difference operator, \(\lambda\) is the regularization parameter, and \(\mathbf{x}\) is the denoised image.
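To make the functional concrete, the following NumPy sketch evaluates it directly. It assumes, as the `append=0` comment in the solver cell suggests, that \(C\mathbf{x}\) stacks the vertical and horizontal forward differences (each padded with a trailing zero so both have the image's shape) along a leading axis, and that the ℓ2 norm in \(\| \cdot \|_{2,1}\) is taken over that axis. The names `finite_diff` and `l1tv_objective` are hypothetical helpers, not part of scico.

```python
import numpy as np


def finite_diff(x):
    """Hypothetical stand-in for C: forward differences along axes 0 and 1,
    each with a trailing zero appended (numpy diff with append=0), stacked
    into an array of shape (2, *x.shape)."""
    return np.stack([np.diff(x, axis=0, append=0), np.diff(x, axis=1, append=0)])


def l1tv_objective(x, y, lam):
    """||y - x||_1 + lam * ||C x||_{2,1}, with the l2 norm taken over axis 0."""
    fidelity = np.sum(np.abs(y - x))
    Cx = finite_diff(x)
    tv = np.sum(np.sqrt(np.sum(Cx**2, axis=0)))  # per-pixel gradient magnitude, summed
    return fidelity + lam * tv
```

On a 4 × 4 zero image with a single unit impulse, keeping the impulse (\(\mathbf{x} = \mathbf{y}\)) costs \(\lambda(2 + \sqrt{2}) \approx 5.12\) for \(\lambda = 1.5\), while removing it (\(\mathbf{x} = 0\)) costs only 1 in the data term. This trade-off is what lets ℓ1-TV discard isolated outliers.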
[1]:
import komplot as kplt
from xdesign import SiemensStar, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.examples import spnoise
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
from scipy.ndimage import median_filter
kplt.config_notebook_plotting()
Create a ground truth image and corrupt it with salt-and-pepper noise to obtain a noisy test image.
[2]:
N = 256 # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = 0.5 * x_gt / x_gt.max()
y = spnoise(x_gt, 0.5)
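`spnoise` comes from scico's example utilities. The sketch below illustrates what salt-and-pepper noise does, under the assumption that the second argument (0.5 here) is the fraction of corrupted pixels and that corrupted pixels are set to either 0 or 1 with equal probability. `salt_and_pepper` is a hypothetical stand-in, not the scico function.

```python
import numpy as np


def salt_and_pepper(img, frac, lo=0.0, hi=1.0, seed=0):
    """Hypothetical helper (not scico.examples.spnoise): set a random fraction
    `frac` of pixels to `lo` (pepper) or `hi` (salt), each with probability 1/2."""
    rng = np.random.default_rng(seed)
    out = np.array(img, dtype=float, copy=True)  # leave the input untouched
    mask = rng.random(out.shape) < frac  # which pixels are corrupted
    salt = rng.random(out.shape) < 0.5  # salt (hi) versus pepper (lo)
    out[mask & salt] = hi
    out[mask & ~salt] = lo
    return out
```

Since `x_gt` is scaled to a peak of 0.5, salt pixels at 1 are far outside the image's range. Large, sparse errors of this kind are what an ℓ1 data term tolerates better than a squared ℓ2 term.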
Denoise with median filtering.
[3]:
x_med = median_filter(y, size=(5, 5))
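The median filter replaces each pixel by the median of its 5 × 5 neighborhood, so an isolated impulse, being an extreme value within its window, is discarded rather than averaged in. The following NumPy sketch shows the idea; `median_filter_2d` is a hypothetical name, it requires NumPy ≥ 1.20 for `sliding_window_view`, and its symmetric border padding is intended to mimic the default `mode="reflect"` of `scipy.ndimage.median_filter`.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def median_filter_2d(img, size=5):
    """Median over a size x size window centred on each pixel (size odd).

    Borders are padded with numpy's "symmetric" mode (d c b a | a b c d),
    which corresponds to scipy.ndimage's "reflect" boundary mode."""
    r = size // 2
    padded = np.pad(img, r, mode="symmetric")
    windows = sliding_window_view(padded, (size, size))  # shape (H, W, size, size)
    return np.median(windows, axis=(-2, -1))
```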
Denoise with ℓ1 total variation.
[4]:
λ = 1.5e0
g_loss = loss.Loss(y=y, f=functional.L1Norm())
g_tv = λ * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
solver = ADMM(
f=None,
g_list=[g_loss, g_tv],
C_list=[linop.Identity(input_shape=y.shape), C],
rho_list=[5e0, 5e0],
x0=y,
maxiter=100,
subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
x_tv = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res
-----------------------------------------------------------------
0 2.23e+00 4.235e+04 1.410e+02 6.768e+02 0 0.000e+00
10 3.43e+00 1.904e+04 1.223e+01 3.909e+01 8 8.948e-04
20 3.70e+00 1.903e+04 2.028e+00 7.437e+00 5 7.223e-04
30 3.90e+00 1.904e+04 8.546e-01 2.779e+00 4 5.360e-04
40 4.08e+00 1.904e+04 4.699e-01 1.551e+00 3 6.756e-04
50 4.25e+00 1.904e+04 3.017e-01 1.053e+00 2 9.842e-04
60 4.40e+00 1.904e+04 2.261e-01 5.246e-01 2 7.396e-04
70 4.51e+00 1.904e+04 1.630e-01 4.010e-01 1 8.978e-04
80 4.62e+00 1.904e+04 1.352e-01 1.834e-01 1 9.490e-04
90 4.73e+00 1.904e+04 1.158e-01 1.554e-01 1 8.789e-04
99 4.82e+00 1.904e+04 1.033e-01 1.293e-01 1 7.632e-04
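The solver cell splits the problem with \(\mathbf{z}_1 = \mathbf{x}\) for the ℓ1 data term and \(\mathbf{z}_2 = C\mathbf{x}\) for the TV term, with `f=None`. The following NumPy sketch spells out standard scaled-form ADMM for that splitting. Each iteration has an \(\mathbf{x}\)-step that solves \((\rho_1 I + \rho_2 C^T C)\mathbf{x} = \rho_1(\mathbf{z}_1 - \mathbf{u}_1) + \rho_2 C^T(\mathbf{z}_2 - \mathbf{u}_2)\) by conjugate gradient, which is the kind of linear subproblem `LinearSubproblemSolver` handles (the CG It and CG Res columns appear to report this inner solve), followed by closed-form proximal steps for the two terms. This is an illustration, not scico's implementation. All function names are hypothetical, \(C\) is assumed to be forward differences with a zero appended, and the primal and dual residuals use textbook definitions whose scaling may differ from the Prml Rsdl and Dual Rsdl columns above.

```python
import numpy as np


def diff_op(x):
    """C x: forward differences along axes 0 and 1 with a trailing zero appended."""
    return np.stack([np.diff(x, axis=0, append=0), np.diff(x, axis=1, append=0)])


def diff_adj(w):
    """C^T w: adjoint of diff_op, using (C_k^T w)[i] = w[i-1] - w[i] with w[-1] = 0."""
    return -np.diff(w[0], axis=0, prepend=0) - np.diff(w[1], axis=1, prepend=0)


def prox_l1_fidelity(v, y, t):
    """prox of t * ||y - z||_1 at v: soft-threshold v - y by t, then add y back."""
    d = v - y
    return y + np.sign(d) * np.maximum(np.abs(d) - t, 0.0)


def prox_l21(v, t):
    """prox of t * ||z||_{2,1} (l2 over axis 0): per-pixel block soft threshold."""
    nrm = np.sqrt(np.sum(v**2, axis=0, keepdims=True))
    return np.maximum(1.0 - t / np.maximum(nrm, 1e-12), 0.0) * v


def cg(A, b, x0, tol=1e-8, maxiter=100):
    """Plain conjugate gradient for a symmetric positive definite operator A."""
    x = x0.copy()
    r = b - A(x)
    p = r.copy()
    rs = np.vdot(r, r)
    for _ in range(maxiter):
        if np.sqrt(rs) <= tol:
            break
        Ap = A(p)
        alpha = rs / np.vdot(p, Ap)
        x += alpha * p
        r -= alpha * Ap
        rs_new = np.vdot(r, r)
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x


def l1tv_admm(y, lam, rho1=5.0, rho2=5.0, maxiter=200):
    """Scaled-form ADMM with z1 = x (l1 fidelity) and z2 = C x (l21 TV term)."""
    x = y.astype(float).copy()  # start from the noisy image, like x0=y above
    z1, z2 = x.copy(), diff_op(x)
    u1, u2 = np.zeros_like(z1), np.zeros_like(z2)

    def A(v):  # x-step normal operator: rho1 I + rho2 C^T C
        return rho1 * v + rho2 * diff_adj(diff_op(v))

    prml, dual = [], []
    for _ in range(maxiter):
        # x-step: (rho1 I + rho2 C^T C) x = rho1 (z1 - u1) + rho2 C^T (z2 - u2)
        x = cg(A, rho1 * (z1 - u1) + rho2 * diff_adj(z2 - u2), x)
        Cx = diff_op(x)
        z1_old, z2_old = z1, z2
        # z-steps: proximal operators of g_i / rho_i
        z1 = prox_l1_fidelity(x + u1, y, 1.0 / rho1)
        z2 = prox_l21(Cx + u2, lam / rho2)
        # scaled dual updates
        u1 = u1 + x - z1
        u2 = u2 + Cx - z2
        # textbook residuals for the stacked constraint [I; C] x = [z1; z2]
        prml.append(np.sqrt(np.sum((x - z1) ** 2) + np.sum((Cx - z2) ** 2)))
        dual.append(np.linalg.norm(rho1 * (z1 - z1_old) + rho2 * diff_adj(z2 - z2_old)))
    return x, np.array(prml), np.array(dual)
```

Because the identity block makes \(\rho_1 I + \rho_2 C^T C\) symmetric positive definite, CG is a natural choice for the \(\mathbf{x}\)-step, and warm-starting it from the previous \(\mathbf{x}\) keeps the inner iteration count small as ADMM converges, consistent with the decreasing CG It column in the log.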
Plot results.
[5]:
plt_args = dict(norm=kplt.colors.Normalize(vmin=0, vmax=1.0))
fig, ax = kplt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(13, 12))
kplt.imview(x_gt, title="Ground truth", ax=ax[0, 0], **plt_args)
kplt.imview(y, title="Noisy image", ax=ax[0, 1], **plt_args)
kplt.imview(
x_med,
title=f"Median filtering: {metric.psnr(x_gt, x_med):.2f} (dB)",
ax=ax[1, 0],
**plt_args,
)
kplt.imview(
x_tv,
title=f"ℓ1-TV denoising: {metric.psnr(x_gt, x_tv):.2f} (dB)",
ax=ax[1, 1],
**plt_args,
)
fig.show()
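The panel titles report PSNR computed by `metric.psnr`. A minimal sketch of the usual definition follows. The function name is hypothetical, and taking the peak value as max minus min of the reference when none is given is an assumption about scico's convention, so check the scico documentation before comparing numbers.

```python
import numpy as np


def psnr(reference, comparison, signal_range=None):
    """Peak signal-to-noise ratio in dB.

    If signal_range is None, use max - min of the reference (an assumed
    convention; scico.metric.psnr may differ in detail)."""
    reference = np.asarray(reference, dtype=float)
    comparison = np.asarray(comparison, dtype=float)
    if signal_range is None:
        signal_range = reference.max() - reference.min()
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)
```

Under the max-minus-min assumption, the reference here has a range of 0.5 (a zero-padded phantom scaled to a peak of 0.5). Using a nominal peak of 1.0 instead would raise every PSNR by \(20 \log_{10} 2 \approx 6.02\) dB, so only PSNR values computed with the same convention should be compared.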
Plot convergence statistics.
[6]:
fig, ax = kplt.subplots(nrows=1, ncols=2, figsize=(12, 5))
kplt.plot(
hist.Objective,
title="Objective function",
xlabel="Iteration",
ylabel="Functional value",
ax=ax[0],
)
kplt.plot(
snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
ylog=True,
title="Residuals",
xlabel="Iteration",
legend=("Primal", "Dual"),
ax=ax[1],
)
fig.show()