ℓ1 Total Variation Denoising

This example demonstrates impulse noise removal via ℓ1 total variation [2] [21] (Sec. 2.4.4), i.e., total variation regularization with an ℓ1 data fidelity term, by minimizing the functional

\[\mathrm{argmin}_{\mathbf{x}} \; \| \mathbf{y} - \mathbf{x} \|_1 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]

where \(\mathbf{y}\) is the noisy image, \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the denoised image.
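The following NumPy sketch evaluates this functional for a given estimate \(\mathbf{x}\). It makes two assumptions. First, the `append=0` option used when constructing \(C\) corresponds to appending a zero along each axis before taking forward differences, as `numpy.diff(..., append=0)` does. Second, the \(\ell_{2,1}\) norm takes the ℓ2 norm across the stacked difference axis and sums over pixels. The function names are hypothetical and not part of SCICO.

```python
import numpy as np


def finite_differences(x):
    """Stack forward differences of a 2D array along axis 0 and axis 1.

    Appending a zero along each axis before differencing (assumed here to
    mirror append=0) keeps both outputs the same shape as x.
    """
    d0 = np.diff(x, axis=0, append=0)  # differences between consecutive rows
    d1 = np.diff(x, axis=1, append=0)  # differences between consecutive columns
    return np.stack((d0, d1))  # shape (2, *x.shape)


def l1tv_objective(x, y, lam):
    """Evaluate ||y - x||_1 + lam * ||C x||_{2,1}."""
    data_term = np.sum(np.abs(y - x))
    cx = finite_differences(x)
    # l2 norm over the stacked difference axis, then l1 (sum) over pixels
    tv_term = np.sum(np.sqrt(np.sum(cx**2, axis=0)))
    return data_term + lam * tv_term
```

You can compare a value such as `l1tv_objective(np.asarray(x_tv), np.asarray(y), λ)` against the Objective column of the solver log. Any agreement depends on the assumed difference convention.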

import komplot as kplt
from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.examples import spnoise
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
from scipy.ndimage import median_filter
kplt.config_notebook_plotting()

Create a ground truth image and corrupt it with salt & pepper noise to obtain a noisy test image.

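N = 256  # image size
phantom = SiemensStar(16)
# Rasterize the phantom at (N - 16) x (N - 16) pixels, then zero-pad by 8 pixels
# on each side to obtain an N x N image.
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = 0.5 * x_gt / x_gt.max()  # scale so that the maximum value is 0.5
y = spnoise(x_gt, 0.5)  # corrupt a fraction 0.5 of the pixels with salt & pepper noise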
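The noise model can be sketched in NumPy as shown below. This is an illustration and not the source of SCICO's `spnoise`. It assumes that the second argument of `spnoise` is the fraction of corrupted pixels, and that each corrupted pixel is set to 0 ("pepper") or 1 ("salt") with equal probability. `add_salt_pepper` is a hypothetical name.

```python
import numpy as np


def add_salt_pepper(img, frac, rng, low=0.0, high=1.0):
    """Replace roughly a fraction `frac` of pixels with `low` or `high`."""
    noisy = np.array(img, dtype=float, copy=True)
    # Select which pixels are corrupted.
    mask = rng.random(noisy.shape) < frac
    # Each corrupted pixel becomes salt (high) or pepper (low) with equal probability.
    salt = rng.random(noisy.shape) < 0.5
    noisy[mask & salt] = high
    noisy[mask & ~salt] = low
    return noisy


rng = np.random.default_rng(0)
noisy = add_salt_pepper(np.full((64, 64), 0.25), 0.5, rng)
```

Under these assumptions, salt pixels take the value 1 even though the ground truth peaks at 0.5. This is consistent with the `vmax=1.0` display range used in the plots below.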

Denoise with median filtering.

x_med = median_filter(y, size=(5, 5))
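Median filtering is a natural baseline for impulse noise. A window median ignores a minority of extreme values, whereas a local mean is pulled toward them. The small check below uses synthetic data that is not part of the original example. It applies `scipy.ndimage.median_filter` and `scipy.ndimage.uniform_filter` to a constant image with two isolated impulses.

```python
import numpy as np
from scipy.ndimage import median_filter, uniform_filter

flat = np.full((9, 9), 0.5)
noisy = flat.copy()
noisy[4, 4] = 1.0  # isolated "salt" impulse
noisy[7, 1] = 0.0  # isolated "pepper" impulse

# Each 5x5 window holds only a few impulse values among 25, so the median is 0.5.
restored = median_filter(noisy, size=(5, 5))
# A 5x5 mean spreads the impulse: the value at (4, 4) becomes (24 * 0.5 + 1) / 25.
blurred = uniform_filter(noisy, size=5)
```

The median output is exactly the clean constant image. The mean filter leaves a residual bump of 0.52 at the salt location.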

Denoise with ℓ1 total variation.

λ = 1.5e0
g_loss = loss.Loss(y=y, f=functional.L1Norm())
g_tv = λ * functional.L21Norm()
# The append=0 option makes the results of horizontal and vertical finite
# differences the same shape, which is required for the L21Norm.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)

solver = ADMM(
    f=None,
    g_list=[g_loss, g_tv],
    C_list=[linop.Identity(input_shape=y.shape), C],
    rho_list=[5e0, 5e0],
    x0=y,
    maxiter=100,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 20}),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x_tv = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.23e+00  4.235e+04  1.410e+02  6.768e+02      0  0.000e+00
  10  3.43e+00  1.904e+04  1.223e+01  3.909e+01      8  8.948e-04
  20  3.70e+00  1.903e+04  2.028e+00  7.437e+00      5  7.223e-04
  30  3.90e+00  1.904e+04  8.546e-01  2.779e+00      4  5.360e-04
  40  4.08e+00  1.904e+04  4.699e-01  1.551e+00      3  6.756e-04
  50  4.25e+00  1.904e+04  3.017e-01  1.053e+00      2  9.842e-04
  60  4.40e+00  1.904e+04  2.261e-01  5.246e-01      2  7.396e-04
  70  4.51e+00  1.904e+04  1.630e-01  4.010e-01      1  8.978e-04
  80  4.62e+00  1.904e+04  1.352e-01  1.834e-01      1  9.490e-04
  90  4.73e+00  1.904e+04  1.158e-01  1.554e-01      1  8.789e-04
  99  4.82e+00  1.904e+04  1.033e-01  1.293e-01      1  7.632e-04
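With `f=None`, ADMM accesses the two terms of the functional only through their proximal operators. In the standard scaled form of ADMM, each auxiliary variable is updated as \(\mathbf{z}_i = \mathrm{prox}_{g_i/\rho_i}(C_i \mathbf{x} + \mathbf{u}_i)\), where \(\mathbf{u}_i\) is the scaled dual variable. The \(\mathbf{x}\) update solves the linear system \((\rho_1 I + \rho_2 C^T C)\,\mathbf{x} = \rho_1 (\mathbf{z}_1 - \mathbf{u}_1) + \rho_2 C^T (\mathbf{z}_2 - \mathbf{u}_2)\), which `LinearSubproblemSolver` handles with conjugate gradient iterations.

The NumPy sketch below implements the two proximal operators for this problem. It is a textbook illustration with hypothetical function names, not SCICO's implementation, and SCICO's sign convention for the dual variables may differ.

```python
import numpy as np


def prox_l1_loss(v, y, t):
    """Proximal operator of t * ||y - z||_1, evaluated at v.

    Shift by y, soft-threshold the residual by t, then shift back.
    """
    r = v - y
    return y + np.sign(r) * np.maximum(np.abs(r) - t, 0.0)


def prox_l21(v, t, axis=0):
    """Proximal operator of t * ||z||_{2,1}, with the l2 norm taken along `axis`.

    Each vector along `axis` (here, the pair of finite differences at one
    pixel) has its Euclidean length reduced by t, or is set to zero if its
    length is at most t.
    """
    norms = np.linalg.norm(v, axis=axis, keepdims=True)
    safe = np.where(norms > 0, norms, 1.0)  # avoid division by zero
    scale = np.where(norms > t, 1.0 - t / safe, 0.0)
    return scale * v
```

In the z-updates above, \(t = 1/\rho_1\) for the data term. For the TV term, \(t = \lambda/\rho_2\), because \(g_2 = \lambda \|\cdot\|_{2,1}\).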

Plot results.

plt_args = dict(norm=kplt.colors.Normalize(vmin=0, vmax=1.0))
fig, ax = kplt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(13, 12))
kplt.imview(x_gt, title="Ground truth", ax=ax[0, 0], **plt_args)
kplt.imview(y, title="Noisy image", ax=ax[0, 1], **plt_args)
kplt.imview(
    x_med,
    title=f"Median filtering: {metric.psnr(x_gt, x_med):.2f} (dB)",
    ax=ax[1, 0],
    **plt_args,
)
kplt.imview(
    x_tv,
    title=f"ℓ1-TV denoising: {metric.psnr(x_gt, x_tv):.2f} (dB)",
    ax=ax[1, 1],
    **plt_args,
)
fig.show()
[Figure: panels "Ground truth", "Noisy image", "Median filtering" and "ℓ1-TV denoising", the last two with PSNR in dB]
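The PSNR values in the panel titles compare each estimate against the ground truth. A common definition is \(\mathrm{PSNR} = 20 \log_{10} R - 10 \log_{10} \mathrm{MSE}\), where \(R\) is the signal range and MSE is the mean squared error. The sketch below implements that definition. It is not the source of SCICO's `metric.psnr`, and how that function chooses \(R\) by default (for example, from the range of the reference image) is an assumption here. For that reason, the hypothetical `psnr` function below takes `signal_range` explicitly.

```python
import numpy as np


def psnr(reference, estimate, signal_range):
    """Peak signal-to-noise ratio in dB: 20 log10(R) - 10 log10(MSE)."""
    mse = np.mean((np.asarray(reference) - np.asarray(estimate)) ** 2)
    return 20.0 * np.log10(signal_range) - 10.0 * np.log10(mse)
```

Because the ground truth here is scaled to a maximum of 0.5, the choice between \(R = 0.5\) and \(R = 1\) shifts every reported value by about 6 dB.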

Plot convergence statistics.

fig, ax = kplt.subplots(nrows=1, ncols=2, figsize=(12, 5))
kplt.plot(
    hist.Objective,
    title="Objective function",
    xlabel="Iteration",
    ylabel="Functional value",
    ax=ax[0],
)
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
    ax=ax[1],
)
fig.show()
[Figure: objective function value and primal/dual residuals (log scale) versus iteration]
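The residual plot tracks how far the iterates are from satisfying the splitting constraints \(C_i \mathbf{x} = \mathbf{z}_i\) (primal residual) and the dual optimality condition (dual residual). The sketch below uses the standard textbook definitions:

- primal: \(\sqrt{\sum_i \|C_i \mathbf{x} - \mathbf{z}_i\|_2^2}\)
- dual: \(\|\sum_i \rho_i C_i^T (\mathbf{z}_i - \mathbf{z}_i^{\mathrm{prev}})\|_2\)

SCICO's `Prml Rsdl` and `Dual Rsdl` columns may use a different scaling. The function and argument names are hypothetical.

```python
import numpy as np


def admm_residuals(x, z, z_prev, ops, adjoints, rhos):
    """Return (primal, dual) residual norms for the constraints C_i x = z_i.

    ops[i] applies C_i, adjoints[i] applies C_i^T, and z / z_prev hold the
    current and previous auxiliary variables, one entry per term g_i.
    """
    primal_sq = sum(np.sum((C(x) - zi) ** 2) for C, zi in zip(ops, z))
    dual_vec = sum(
        rho * CT(zi - zpi) for CT, zi, zpi, rho in zip(adjoints, z, z_prev, rhos)
    )
    return float(np.sqrt(primal_sq)), float(np.linalg.norm(dual_vec))
```

Both quantities tend to zero as ADMM converges. The steady decrease in both curves is therefore the practical convergence indicator, even though the objective value settles within the first 20 iterations.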