Total Variation Denoising with Constraint (APGM)#

This example demonstrates the solution of the isotropic total variation (TV) denoising problem

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - \mathbf{x} \|_2^2 + \lambda R(\mathbf{x}) + \iota_C(\mathbf{x}) \;,\]

where \(R\) is a TV regularizer, \(\iota_C(\cdot)\) is the indicator function of the constraint set \(C\), and \(C = \{ \mathbf{x} \, | \, x_i \in [0, 1] \}\), i.e. the set of vectors with components constrained to lie in the interval \([0, 1]\). The problem is solved separately with \(R\) taken as the isotropic and the anisotropic TV regularizer.

The solution via APGM is based on the approach in [9], which involves constructing a dual of the constrained denoising problem; APGM is then applied to minimize the resulting dual. In this case, switching between the two regularizers corresponds to switching between two different projectors.
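
In this formulation, the denoised image is recovered from a solution \(\mathbf{p}^*\) of the dual problem by a single projection step (this summarizes the recovery step applied after each solver run below; \(C\) denotes the finite difference operator defined in a later cell and \(P_{[0,1]}\) denotes elementwise clipping to \([0, 1]\)):

\[\mathbf{x}^* = P_{[0,1]} \left( \mathbf{y} - \lambda C^T \mathbf{p}^* \right) \;.\]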

[1]:
from typing import Callable, Optional, Union

import jax.numpy as jnp

from xdesign import SiemensStar, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, operator, plot
from scico.numpy import Array, BlockArray
from scico.optimize.pgm import AcceleratedPGM, RobustLineSearchStepSize
from scico.util import device_info
plot.config_notebook_plotting()

Create a ground truth image.

[2]:
N = 256  # image size
phantom = SiemensStar(16)
x_gt = snp.pad(discrete_phantom(phantom, N - 16), 8)
x_gt = x_gt / x_gt.max()

Add noise to create a noisy test image.

[3]:
σ = 0.75  # noise standard deviation
noise, key = scico.random.randn(x_gt.shape, seed=0)
y = x_gt + σ * noise

Define finite difference operator and adjoint.

[4]:
# The append=0 option appends 0 to the input along the axis
# prior to performing the difference to make the results of
# horizontal and vertical finite differences the same shape.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
A = C.adj
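
As an informal sanity check (illustrative code, not part of the original example), the append=0 option makes both difference images the same shape as the input, so C maps an N × N image to a stack of two N × N difference images and its adjoint maps back to the image domain.

# Illustrative shape check (assumes the default axes of FiniteDifference).
print(C(y).shape)     # expected: (2, 256, 256)
print(A(C(y)).shape)  # expected: (256, 256)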

Define a zero array as initial estimate.

[5]:
x0 = jnp.zeros(C(y).shape)

Define the dual of the total variation denoising problem.

[6]:
class DualTVLoss(loss.Loss):
    def __init__(
        self,
        y: Union[Array, BlockArray],
        A: Optional[Union[Callable, operator.Operator]] = None,
        lmbda: float = 0.5,
    ):
        self.functional = functional.SquaredL2Norm()
        super().__init__(y=y, A=A, scale=1.0)
        self.lmbda = lmbda

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        # x is the dual variable; xint is the corresponding primal-domain
        # image before projection onto the constraint set.
        xint = self.y - self.lmbda * self.A(x)
        return -1.0 * self.functional(xint - jnp.clip(xint, 0.0, 1.0)) + self.functional(xint)
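
The dual loss is a differentiable function of the dual variable, and the gradient step in APGM requires its gradient, which can be obtained by automatic differentiation. A quick check (illustrative code, not part of the original example):

import jax

# Gradient of the dual loss at the zero dual variable; the result has the
# same shape as the dual variable itself.
f_check = DualTVLoss(y=y, A=A, lmbda=1.0)
print(jax.grad(f_check)(x0).shape)  # expected: same shape as x0, i.e. (2, 256, 256)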

Denoise with isotropic total variation. Define the projector corresponding to isotropic TV.

[7]:
# Evaluation of functional set to zero.
class IsoProjector(functional.Functional):
    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        return 0.0

    def prox(self, v: Array, lam: float, **kwargs) -> Array:
        # Pointwise l2 norm of the per-pixel difference vector.
        norm_v_ptp = jnp.sqrt(jnp.sum(jnp.abs(v) ** 2, axis=0))

        # Rescale each per-pixel difference vector onto the unit l2 ball.
        x_out = v / jnp.maximum(jnp.ones(v.shape), norm_v_ptp)
        # The last column of the first difference component and the last row
        # of the second are clipped to [-1, 1] independently.
        out1 = v[0, :, -1] / jnp.maximum(jnp.ones(v[0, :, -1].shape), jnp.abs(v[0, :, -1]))
        x_out = x_out.at[0, :, -1].set(out1)
        out2 = v[1, -1, :] / jnp.maximum(jnp.ones(v[1, -1, :].shape), jnp.abs(v[1, -1, :]))
        x_out = x_out.at[1, -1, :].set(out2)

        return x_out
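
An informal check of this projector (illustrative code, not part of the original example): away from the image boundary it rescales each per-pixel difference vector onto the unit l2 ball, so the pointwise norms of the output should not exceed one.

proj_check = IsoProjector()
v_check, _ = scico.random.randn((2, 8, 8), seed=1)
w_check = proj_check.prox(3.0 * v_check, 1.0)
# Maximum pointwise norm away from the boundary rows/columns.
print(float(jnp.sqrt(jnp.sum(w_check**2, axis=0))[:-1, :-1].max()))  # at most 1.0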

Set up the AcceleratedPGM solver object, using a RobustLineSearchStepSize object for the step size, and run the solver.

[8]:
reg_weight_iso = 1.4e0
f_iso = DualTVLoss(y=y, A=A, lmbda=reg_weight_iso)
g_iso = IsoProjector()

solver_iso = AcceleratedPGM(
    f=f_iso,
    g=g_iso,
    L0=16.0 * f_iso.lmbda**2,
    x0=x0,
    maxiter=100,
    itstat_options={"display": True, "period": 10},
    step_size=RobustLineSearchStepSize(),
)

# Run the solver.
print(f"Solving on {device_info()}\n")
x = solver_iso.solve()
hist_iso = solver_iso.itstat_object.history(transpose=True)
# Project to constraint set.
x_iso = jnp.clip(y - f_iso.lmbda * f_iso.A(x), 0.0, 1.0)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  L          Residual
-----------------------------------------------
   0  4.24e+00  3.141e+04  2.822e+01  1.835e+01
  10  4.56e+00  1.749e+04  1.968e+01  9.306e+00
  20  4.78e+00  1.513e+04  2.745e+01  5.370e+00
  30  4.98e+00  1.472e+04  1.914e+01  3.691e+00
  40  5.17e+00  1.461e+04  2.670e+01  2.082e+00
  50  5.35e+00  1.456e+04  1.862e+01  1.982e+00
  60  5.58e+00  1.453e+04  2.597e+01  1.394e+00
  70  5.75e+00  1.451e+04  1.811e+01  1.484e+00
  80  5.93e+00  1.450e+04  2.526e+01  1.130e+00
  90  6.09e+00  1.449e+04  1.761e+01  1.245e+00
  99  6.26e+00  1.449e+04  2.729e+01  9.203e-01

Denoise with anisotropic total variation for comparison. Define the projector corresponding to anisotropic TV.

[9]:
# Evaluation of functional set to zero.
class AnisoProjector(functional.Functional):
    has_eval = True
    has_prox = True

    def __call__(self, x: Union[Array, BlockArray]) -> float:
        return 0.0

    def prox(self, v: Array, lam: float, **kwargs) -> Array:
        # Clip each difference component to [-1, 1] independently.
        return v / jnp.maximum(jnp.ones(v.shape), jnp.abs(v))
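
The difference between the two regularizers is visible directly in the projectors (illustrative code, not part of the original example): the anisotropic projector clips each difference component to [-1, 1] independently, while the isotropic projector rescales each per-pixel difference vector by its joint norm.

v_cmp, _ = scico.random.randn((2, 4, 4), seed=2)
print(float(jnp.abs(AnisoProjector().prox(2.0 * v_cmp, 1.0)).max()))  # at most 1.0
print(float(jnp.abs(IsoProjector().prox(2.0 * v_cmp, 1.0)).max()))    # also at most 1.0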

Set up the AcceleratedPGM solver object, using a RobustLineSearchStepSize object for the step size, and run the solver. The regularization weight was tuned to give the same data fidelity as the isotropic case.

[10]:
reg_weight_aniso = 1.2e0
f = DualTVLoss(y=y, A=A, lmbda=reg_weight_aniso)
g = AnisoProjector()

solver = AcceleratedPGM(
    f=f,
    g=g,
    L0=16.0 * f.lmbda**2,
    x0=x0,
    maxiter=100,
    itstat_options={"display": True, "period": 10},
    step_size=RobustLineSearchStepSize(),
)

# Run the solver.
print()
x = solver.solve()
# Project to constraint set.
x_aniso = jnp.clip(y - f.lmbda * f.A(x), 0.0, 1.0)

Iter  Time      Objective  L          Residual
-----------------------------------------------
   0  3.00e-01  3.141e+04  2.074e+01  2.141e+01
  10  5.07e-01  1.753e+04  1.446e+01  1.073e+01
  20  6.86e-01  1.518e+04  2.017e+01  6.434e+00
  30  8.52e-01  1.475e+04  1.406e+01  4.318e+00
  40  1.03e+00  1.465e+04  1.962e+01  2.370e+00
  50  1.20e+00  1.460e+04  1.368e+01  2.274e+00
  60  1.38e+00  1.457e+04  1.908e+01  1.588e+00
  70  1.56e+00  1.456e+04  1.330e+01  1.674e+00
  80  1.74e+00  1.455e+04  1.856e+01  1.270e+00
  90  1.91e+00  1.454e+04  1.294e+01  1.393e+00
  99  2.10e+00  1.453e+04  2.005e+01  1.017e+00

Compute the data fidelity.

[11]:
df = hist_iso.Objective[-1]
print(f"\nData fidelity for isotropic TV was {df:.2e}")
hist = solver.itstat_object.history(transpose=True)
df = hist.Objective[-1]
print(f"Data fidelity for anisotropic TV was {df:.2e}")

Data fidelity for isotropic TV was 1.45e+04
Data fidelity for anisotropic TV was 1.45e+04
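
As an additional informal comparison (illustrative code, not part of the original example; assumes the scico.metric module is available), the PSNR of each result with respect to the ground truth can also be computed:

from scico import metric

print(f"PSNR for isotropic TV:   {metric.psnr(x_gt, x_iso):.2f} dB")
print(f"PSNR for anisotropic TV: {metric.psnr(x_gt, x_aniso):.2f} dB")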

Plot results.

[12]:
plt_args = dict(norm=plot.matplotlib.colors.Normalize(vmin=0, vmax=1.5))
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison")
fig.show()

# zoomed version
fig, ax = plot.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0], **plt_args)
plot.imview(y, title="Noisy version", fig=fig, ax=ax[0, 1], **plt_args)
plot.imview(x_iso, title="Isotropic TV denoising", fig=fig, ax=ax[1, 0], **plt_args)
plot.imview(x_aniso, title="Anisotropic TV denoising", fig=fig, ax=ax[1, 1], **plt_args)
ax[0, 0].set_xlim(N // 4, N // 4 + N // 2)
ax[0, 0].set_ylim(N // 4, N // 4 + N // 2)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="Arbitrary Units"
)
fig.suptitle("Denoising comparison (zoomed)")
fig.show()
[Figure: Denoising comparison of ground truth, noisy version, isotropic TV, and anisotropic TV results]
[Figure: Denoising comparison (zoomed)]