Polar Total Variation Denoising (PDHG)
This example compares denoising via standard isotropic total variation (TV) regularization [50] [27] and a variant based on local polar coordinates, as described in [29]. It solves the denoising problem
\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} \|_2^2 + \lambda R(\mathbf{x}) \;,\]
where \(R\) is either the isotropic or the polar TV regularizer. In both cases the problem is solved using the primal–dual hybrid gradient (PDHG) algorithm.
[1]:
import komplot as kplt
from xdesign import SiemensStar, discrete_phantom
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric
from scico.optimize import PDHG
from scico.util import device_info
# Notebook-specific plotting setup provided by komplot (see komplot docs for details).
kplt.config_notebook_plotting()
Create a ground truth image.
[2]:
N = 256  # image size
border = 8  # width of the zero border added on every side of the phantom

# Render the Siemens star at the interior size, then pad back out to N x N.
phantom = SiemensStar(16)
x_gt = discrete_phantom(phantom, N - 2 * border)
x_gt = snp.pad(x_gt, border)

# Scale so that the brightest pixel has value 1.
x_gt = x_gt / x_gt.max()
Add noise to create a noisy test image.
[3]:
σ = 0.75  # standard deviation of the additive noise

# Draw a fixed-seed noise realization (reproducible) of the same shape as the image.
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = σ * noise + x_gt
Denoise with standard isotropic total variation.
[4]:
# Regularization weight for standard isotropic TV.
λ_std = 0.8e0

# Data fidelity term; the same loss object is reused for the polar TV problem.
f = loss.SquaredL2Loss(y=y)
# Isotropic TV regularizer: L21Norm applied to the stacked finite differences.
g_std = λ_std * functional.L21Norm()

# Passing append=0 gives the horizontal and vertical differences identical
# shapes, which L21Norm requires.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)

# PDHG step sizes estimated from C, followed by the remaining solver options.
tau, sigma = PDHG.estimate_parameters(C, ratio=20.0)
pdhg_opts = dict(
    tau=tau,
    sigma=sigma,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
solver = PDHG(f=f, g=g_std, C=C, **pdhg_opts)

print(f"Solving on {device_info()}\n")
solver.solve()
x_std = solver.x
hist_std = solver.itstat_object.history(transpose=True)
print()
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl
-----------------------------------------------
0 9.98e-01 3.035e+04 2.247e+02 5.676e+01
10 1.75e+00 2.261e+04 5.950e+01 5.575e+00
20 1.83e+00 2.111e+04 2.715e+01 2.254e+00
30 1.91e+00 2.073e+04 1.271e+01 1.148e+00
40 1.99e+00 2.063e+04 6.003e+00 6.890e-01
50 2.06e+00 2.059e+04 2.906e+00 4.493e-01
60 2.13e+00 2.058e+04 1.457e+00 3.186e-01
70 2.21e+00 2.057e+04 7.698e-01 2.445e-01
80 2.29e+00 2.056e+04 4.446e-01 1.964e-01
90 2.36e+00 2.056e+04 2.850e-01 1.647e-01
100 2.42e+00 2.056e+04 2.029e-01 1.411e-01
110 2.48e+00 2.056e+04 1.535e-01 1.231e-01
120 2.55e+00 2.056e+04 1.216e-01 1.088e-01
130 2.62e+00 2.056e+04 1.010e-01 9.689e-02
140 2.69e+00 2.056e+04 8.517e-02 8.715e-02
150 2.75e+00 2.056e+04 7.335e-02 7.857e-02
160 2.80e+00 2.056e+04 6.411e-02 7.130e-02
170 2.84e+00 2.056e+04 5.632e-02 6.510e-02
180 2.89e+00 2.056e+04 4.851e-02 5.989e-02
190 2.94e+00 2.056e+04 4.309e-02 5.541e-02
199 2.99e+00 2.056e+04 3.818e-02 5.187e-02
Denoise with polar total variation for comparison.
[5]:
# Regularization weight for polar TV, tuned so that the data fidelity of the
# result is approximately the same as for the isotropic TV solution.
λ_plr = 1.2e0
g_plr = λ_plr * functional.L1Norm()

# Gradient expressed in local polar coordinates; its output stacks two
# components along the leading axis.
G = linop.PolarGradient(input_shape=x_gt.shape)
# Relative weighting of the two gradient components (0.3 for the first, 1.0 for
# the second). The (2, 1, 1) reshape lets the weights broadcast over the spatial
# axes. G.shape[0] is the output shape of G, i.e. the input shape of D.
# NOTE(review): confirm in the PolarGradient docs which of the angular/radial
# components is placed first, and hence which one is down-weighted.
D = linop.Diagonal(snp.array([0.3, 1.0]).reshape((2, 1, 1)), input_shape=G.shape[0])
C = D @ G

# Same solver configuration as for isotropic TV, with step sizes re-estimated
# for the new operator C.
tau, sigma = PDHG.estimate_parameters(C, ratio=20.0)
solver = PDHG(
    f=f,
    g=g_plr,
    C=C,
    tau=tau,
    sigma=sigma,
    maxiter=200,
    itstat_options={"display": True, "period": 10},
)
solver.solve()
hist_plr = solver.itstat_object.history(transpose=True)
x_plr = solver.x
print()
Iter Time Objective Prml Rsdl Dual Rsdl
-----------------------------------------------
0 1.18e-01 3.154e+04 2.247e+02 4.203e+01
10 3.87e-01 2.356e+04 6.907e+01 9.698e+00
20 4.89e-01 2.146e+04 3.139e+01 5.128e+00
30 5.83e-01 2.076e+04 1.556e+01 3.010e+00
40 6.74e-01 2.049e+04 8.294e+00 1.963e+00
50 7.60e-01 2.034e+04 4.836e+00 1.368e+00
60 8.45e-01 2.026e+04 3.027e+00 1.016e+00
70 9.26e-01 2.022e+04 2.030e+00 7.922e-01
80 9.97e-01 2.019e+04 1.445e+00 6.385e-01
90 1.05e+00 2.017e+04 1.070e+00 5.342e-01
100 1.11e+00 2.015e+04 8.161e-01 4.579e-01
110 1.16e+00 2.015e+04 6.448e-01 4.041e-01
120 1.22e+00 2.014e+04 5.232e-01 3.604e-01
130 1.28e+00 2.013e+04 4.355e-01 3.260e-01
140 1.34e+00 2.013e+04 3.698e-01 2.978e-01
150 1.39e+00 2.013e+04 3.170e-01 2.737e-01
160 1.45e+00 2.012e+04 2.854e-01 2.528e-01
170 1.52e+00 2.012e+04 2.552e-01 2.341e-01
180 1.59e+00 2.012e+04 2.231e-01 2.191e-01
190 1.66e+00 2.012e+04 2.040e-01 2.049e-01
199 1.73e+00 2.012e+04 1.870e-01 1.937e-01
Compute and print the data fidelity.
[6]:
# Report the data fidelity f(x) and the SNR with respect to the ground truth
# for each reconstruction.
for name, x in (("Isotropic", x_std), ("Polar", x_plr)):
    label = f"{name} TV"
    df = f(x)
    snr = metric.snr(x_gt, x)
    print(f"Data fidelity for {label:12}: {df:.2e} SNR: {snr:5.2f} dB")
Data fidelity for Isotropic TV: 1.77e+04 SNR: 9.57 dB
Data fidelity for Polar TV : 1.78e+04 SNR: 11.17 dB
Plot results.
[7]:
# Common display range for all images so that they are directly comparable.
plt_args = dict(norm=kplt.colors.Normalize(vmin=0, vmax=1.5))


def _plot_comparison(title, zoom=False):
    """Display the ground truth, noisy image, and both TV reconstructions.

    The four images are shown in a 2x2 grid with shared axes and a common
    colorbar. If `zoom` is True, the shared axes are restricted to the
    central (N/2) x (N/2) region of the images.

    Returns the figure and the array of axes.
    """
    fig, ax = kplt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
    kplt.imview(x_gt, title="Ground truth", ax=ax[0, 0], **plt_args)
    kplt.imview(y, title="Noisy version", ax=ax[0, 1], **plt_args)
    kplt.imview(x_std, title="Isotropic TV denoising", ax=ax[1, 0], **plt_args)
    kplt.imview(x_plr, title="Polar TV denoising", ax=ax[1, 1], **plt_args)
    if zoom:
        lo, hi = N // 4, N // 4 + N // 2
        # The axes are shared, so limiting one subplot zooms all four.
        ax[0, 0].set_xlim(lo, hi)
        # Image rows increase downwards (matplotlib imshow default
        # origin="upper", assumed to be what imview uses), so the y limits are
        # given as (bottom, top) = (hi, lo). Passing (lo, hi) would flip the
        # zoomed images vertically.
        ax[0, 0].set_ylim(hi, lo)
    fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
    fig.colorbar(
        ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
    )
    fig.suptitle(title)
    fig.show()
    return fig, ax


fig, ax = _plot_comparison("Denoising comparison")
fig, ax = _plot_comparison("Denoising comparison (zoomed)", zoom=True)

# Convergence comparison: objective and primal/dual residuals per iteration.
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=False, figsize=(20, 5))
for axis, (field, title) in zip(
    ax,
    (
        ("Objective", "Objective function"),
        ("Prml_Rsdl", "Primal residual"),
        ("Dual_Rsdl", "Dual residual"),
    ),
):
    kplt.plot(
        snp.array((getattr(hist_std, field), getattr(hist_plr, field))).T,
        ylog=True,
        title=title,
        xlabel="Iteration",
        legend=("Standard", "Polar"),
        ax=axis,
    )
fig.show()