PPP (with BM3D) CT Reconstruction (ADMM with Fast SVMBIR Prox)

This example demonstrates the solution of a tomographic reconstruction problem using the Plug-and-Play Priors framework [50], with BM3D [16] as the denoiser and SVMBIR [48] for tomographic projection.

There are two versions of this example, which solve the same problem in different ways. This version uses the data fidelity term as one of the ADMM \(g\) functionals, so that the optimization with respect to the data fidelity can exploit the internal prox of the SVMBIRExtendedLoss and SVMBIRSquaredL2Loss functionals. The other version solves the ADMM subproblem corresponding to the data fidelity term via conjugate gradient (CG).

This example explores two ways of exploiting the SVMBIR internal prox:

1. Using SVMBIRSquaredL2Loss together with the BM3D pseudo-functional and a non-negative indicator function.
2. Using SVMBIRExtendedLoss, which includes a non-negativity constraint, together with the BM3D pseudo-functional.
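A toy problem makes the splitting easier to see. The code below is a minimal scaled-form ADMM for \(\min_x \sum_i g_i(x)\), which is the case f = None with every C_i = I used by both solvers below. It is not SCICO's implementation. Soft thresholding, the prox of a hypothetical \(\lambda\|x\|_1\) term, stands in for the BM3D denoiser. The data term uses A = I so that its prox has a closed form. The helper names are illustrative only.

```python
import numpy as np


def prox_data(v, t, y):
    """prox of t * 0.5 * ||y - x||^2 (toy data fidelity with A = I)."""
    return (v + t * y) / (1.0 + t)


def prox_l1(v, t, lam):
    """prox of t * lam * ||x||_1 (soft thresholding); a toy stand-in for BM3D."""
    return np.sign(v) * np.maximum(np.abs(v) - t * lam, 0.0)


def prox_nonneg(v, t):
    """prox of the non-negative indicator: projection onto x >= 0 (t is unused)."""
    return np.maximum(v, 0.0)


def consensus_admm(proxes, rhos, x0, maxiter=1000):
    """Scaled-form ADMM for min_x sum_i g_i(x), i.e. f = None and every C_i = I."""
    x = x0.copy()
    z = [x0.copy() for _ in proxes]
    u = [np.zeros_like(x0) for _ in proxes]
    for _ in range(maxiter):
        # x-update: with all C_i = I this is a rho-weighted average of z_i - u_i.
        x = sum(r * (zi - ui) for r, zi, ui in zip(rhos, z, u)) / sum(rhos)
        # z-updates: one prox per functional, each with parameter 1 / rho_i.
        z = [p(x + ui, 1.0 / r) for p, r, ui in zip(proxes, rhos, u)]
        # Scaled dual updates.
        u = [ui + x - zi for ui, zi in zip(u, z)]
    return x


rng = np.random.default_rng(0)
y = rng.normal(size=50)
lam = 0.3
proxes = [
    lambda v, t: prox_data(v, t, y),
    lambda v, t: prox_l1(v, t, lam),
    prox_nonneg,
]
x_admm = consensus_admm(proxes, rhos=[1.0, 1.0, 1.0], x0=np.zeros_like(y))
# Closed-form minimizer of 0.5 ||y - x||^2 + lam ||x||_1 + indicator(x >= 0).
x_exact = np.maximum(y - lam, 0.0)
print(np.max(np.abs(x_admm - x_exact)))
```

The toy problem is separable, so its exact minimizer is \(\max(y - \lambda, 0)\), and the ADMM iterate should approach it.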

[1]:
import numpy as np

import matplotlib.pyplot as plt
import svmbir
from matplotlib.ticker import MaxNLocator
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import metric, plot
from scico.functional import BM3D, NonNegativeIndicator
from scico.linop import Diagonal, Identity
from scico.linop.xray.svmbir import (
    SVMBIRExtendedLoss,
    SVMBIRSquaredL2Loss,
    XRayTransform,
)
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()

Generate a ground truth image.

[2]:
N = 256  # image size
density = 0.025  # attenuation density of the image
np.random.seed(1234)
x_gt = discrete_phantom(Foam(size_range=[0.05, 0.02], gap=0.02, porosity=0.3), size=N - 10)
x_gt = x_gt / np.max(x_gt) * density
x_gt = np.pad(x_gt, 5)
x_gt[x_gt < 0] = 0

Generate the tomographic projector and sinogram.

[3]:
num_angles = N // 2
num_channels = N
angles = snp.linspace(0, snp.pi, num_angles, endpoint=False, dtype=snp.float32)
A = XRayTransform(x_gt.shape, angles, num_channels)
sino = A @ x_gt
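The following crude pixel-driven NumPy sketch illustrates what a parallel-beam projector such as A computes. It is a hypothetical toy: toy_parallel_projector is not part of SCICO or SVMBIR, and their system models are more accurate. Each pixel value is added to the detector bin nearest its projected coordinate \(t = x\cos\theta + y\sin\theta\). The result is a sinogram of shape (num_angles, num_channels), the same layout used for sino here.

```python
import numpy as np


def toy_parallel_projector(img, angles, num_channels):
    """Crude pixel-driven parallel-beam projector (a toy, not SVMBIR's model).

    Each pixel's value is deposited into the detector bin nearest to its
    projected position t = x cos(theta) + y sin(theta); pixels that project
    outside the detector are dropped.
    """
    n_rows, n_cols = img.shape
    # Pixel-center coordinates relative to the image center.
    yy, xx = np.mgrid[0:n_rows, 0:n_cols]
    xx = xx - (n_cols - 1) / 2.0
    yy = yy - (n_rows - 1) / 2.0
    sino = np.zeros((len(angles), num_channels))
    for k, theta in enumerate(angles):
        t = xx * np.cos(theta) + yy * np.sin(theta)
        # Detector bin index, with the detector centered on the rotation axis.
        idx = np.round(t + (num_channels - 1) / 2.0).astype(int)
        inside = (idx >= 0) & (idx < num_channels)
        np.add.at(sino[k], idx[inside], img[inside])
    return sino


N = 32
yy, xx = np.mgrid[0:N, 0:N]
img = (((xx - 15.5) ** 2 + (yy - 15.5) ** 2) <= 8**2).astype(float)  # centered disk
angles = np.linspace(0, np.pi, N // 2, endpoint=False)
sino = toy_parallel_projector(img, angles, num_channels=N)
print(sino.shape)  # (16, 32)
```

When the object lies inside the detector's field of view, every projection preserves the total mass. At \(\theta = 0\) the projection reduces to column sums.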

Impose Poisson noise on the sinogram. A higher max_intensity yields less noise.

[4]:
max_intensity = 2000
expected_counts = max_intensity * np.exp(-sino)
noisy_counts = np.random.poisson(expected_counts).astype(np.float32)
noisy_counts[noisy_counts == 0] = 1  # replace zero counts to avoid log(0)
y = -np.log(noisy_counts / max_intensity)
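The noise model can be checked numerically. For a single ray with line integral p, the counts are Poisson with mean \(I_0 e^{-p}\), where \(I_0\) is max_intensity. A first-order (delta-method) approximation gives \(\mathrm{std}(y) \approx 1/\sqrt{I_0 e^{-p}}\). Raising max_intensity tenfold should therefore reduce the noise in y by roughly \(\sqrt{10}\). The sketch below simulates this for one assumed value, p = 1.

```python
import numpy as np

rng = np.random.default_rng(0)
p = 1.0  # assumed line integral for a single ray
results = {}
for max_intensity in (2000, 20000):
    counts = rng.poisson(max_intensity * np.exp(-p), size=100_000).astype(np.float64)
    counts[counts == 0] = 1  # same zero-count guard as in the cell above
    y = -np.log(counts / max_intensity)
    # First-order prediction of the standard deviation of y.
    predicted = 1.0 / np.sqrt(max_intensity * np.exp(-p))
    results[max_intensity] = (y.std(), predicted)
    print(f"max_intensity={max_intensity}: std(y)={y.std():.4f}, predicted={predicted:.4f}")
```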

Reconstruct using the default MRF prior of SVMBIR [48]. This reconstruction is also used as the initial solution for ADMM.

[5]:
weights = svmbir.calc_weights(y, weight_type="transmission")

x_mrf = svmbir.recon(
    np.array(y[:, np.newaxis]),
    np.array(angles),
    weights=weights[:, np.newaxis],
    num_rows=N,
    num_cols=N,
    positivity=True,
    verbose=0,
)[0]
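The weights enter the data fidelity as \(\tfrac{1}{2}\|y - Ax\|_W^2\), with scale=0.5 in the cells below. This sketch assumes that weight_type="transmission" sets the weights to \(e^{-y}\), as the svmbir documentation describes. Under that assumption, the weights are approximately proportional to each ray's inverse noise variance \(I_0 e^{-p}\). The weighted squared error should then have roughly the same expectation, \(1/I_0\), for every ray. The code simulates this for a few assumed line-integral values. transmission_weights is a hypothetical helper, not the svmbir function.

```python
import numpy as np


def transmission_weights(y):
    # Hypothetical helper, assumed to match svmbir's "transmission" weighting, exp(-y).
    return np.exp(-y)


rng = np.random.default_rng(0)
max_intensity = 2000
normalized_err = {}
for p in (0.5, 1.5, 3.0):  # assumed line integrals for three rays
    counts = rng.poisson(max_intensity * np.exp(-p), size=200_000).astype(np.float64)
    counts[counts == 0] = 1
    y = -np.log(counts / max_intensity)
    w = transmission_weights(y)
    # Weighted squared error scaled by I_0; expected to be close to 1 for every ray.
    normalized_err[p] = max_intensity * np.mean(w * (y - p) ** 2)
print(normalized_err)
```

Rays with high attenuation have few counts and noisy y values, so they receive small weights. This inverse-variance weighting balances their contribution to the loss.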

Convert NumPy arrays to JAX arrays.

[6]:
y = snp.array(y)
x0 = snp.array(x_mrf)
weights = snp.array(weights)

Set the problem parameters and the BM3D pseudo-functional.

[7]:
ρ = 10  # ADMM penalty parameter
σ = density * 0.26  # denoiser sigma
g0 = σ * ρ * BM3D()
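Here is one way to read the σρ scaling in g0. Assume, as in SCICO's ADMM, that the \(g_i\) step evaluates the prox with parameter \(1/\rho_i\). Assume also that scaling a functional by c multiplies its prox parameter by c. BM3D is then effectively called with noise level \(\sigma\rho \cdot (1/\rho) = \sigma\). The hypothetical ToyScaledDenoiser below only tracks that parameter and does no denoising.

```python
import math


class ToyScaledDenoiser:
    """Hypothetical stand-in for c * D, where D is a denoiser pseudo-functional
    whose prox(v, lam) means 'denoise v at noise level lam'."""

    def __init__(self, scale):
        self.scale = scale

    def prox_noise_level(self, lam):
        # Scaling the functional by c multiplies the prox parameter by c.
        return self.scale * lam


density = 0.025
sigma = density * 0.26
levels = {}
for rho in (1.0, 10.0, 100.0):
    g0 = ToyScaledDenoiser(sigma * rho)
    # Assumed ADMM g-step: prox with parameter 1 / rho.
    levels[rho] = g0.prox_noise_level(1.0 / rho)
print(levels)
```

Because ρ cancels, the denoiser noise level stays at σ = 0.26 × density for any choice of ρ. ρ still affects the other prox steps and the convergence behavior.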

Set up the problem using SVMBIRSquaredL2Loss and NonNegativeIndicator.

[8]:
f_l2loss = SVMBIRSquaredL2Loss(
    y=y, A=A, W=Diagonal(weights), scale=0.5, prox_kwargs={"maxiter": 5, "ctol": 0.0}
)
g1 = NonNegativeIndicator()

solver_l2loss = ADMM(
    f=None,
    g_list=[f_l2loss, g0, g1],
    C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape), Identity(x_mrf.shape)],
    rho_list=[ρ, ρ, ρ],
    x0=x0,
    maxiter=20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True, "period": 5},
)
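With f=None and every C_i = I, the x-step solves \(\big(\sum_i \rho_i C_i^T C_i\big) x = \sum_i \rho_i C_i^T (z_i - u_i)\). This is a scaled identity system, which is consistent with the CG It column reporting a single iteration in the iteration logs. The plain NumPy CG below illustrates this. It is a textbook version, not SCICO's LinearSubproblemSolver. The z_i, u_i and the warm start are random placeholders.

```python
import numpy as np


def conjugate_gradient(apply_A, b, x0, tol=1e-3, maxiter=100):
    """Textbook CG for a symmetric positive definite operator; returns (x, iterations)."""
    x = x0.copy()
    r = b - apply_A(x)
    p = r.copy()
    rs = r @ r
    b_norm = np.linalg.norm(b)
    if np.sqrt(rs) <= tol * b_norm:
        return x, 0
    for it in range(1, maxiter + 1):
        Ap = apply_A(p)
        alpha = rs / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) <= tol * b_norm:  # relative residual test
            return x, it
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, maxiter


rng = np.random.default_rng(0)
n = 1000
rhos = [10.0, 10.0, 10.0]
z = [rng.normal(size=n) for _ in rhos]
u = [rng.normal(size=n) for _ in rhos]

# Right-hand side sum_i rho_i C_i^T (z_i - u_i) with C_i = I.
b = sum(rho * (zi - ui) for rho, zi, ui in zip(rhos, z, u))


def apply_A(v):
    # System operator sum_i rho_i C_i^T C_i = (sum_i rho_i) I.
    return sum(rhos) * v


x, iterations = conjugate_gradient(apply_A, b, x0=rng.normal(size=n))  # warm start
print(iterations)  # 1
```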

Run the ADMM solver.

[9]:
print(f"Solving on {device_info()}\n")
x_l2loss = solver_l2loss.solve()
hist_l2loss = solver_l2loss.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  3.67e+00  5.375e-01  5.375e-01      0  0.000e+00
   1  7.05e+00  3.687e-01  3.230e-01      1  3.913e-09
   2  1.07e+01  2.416e-01  2.150e-01      1  2.928e-09
   3  1.32e+01  1.573e-01  1.827e-01      1  2.418e-09
   4  1.66e+01  1.189e-01  1.367e-01      1  1.673e-09
   5  2.02e+01  1.015e-01  9.357e-02      1  2.187e-09
   6  2.40e+01  8.631e-02  6.378e-02      1  6.341e-10
   7  2.77e+01  6.855e-02  5.047e-02      1  5.304e-10
   8  3.09e+01  5.318e-02  4.421e-02      1  5.568e-10
   9  3.40e+01  4.378e-02  3.729e-02      1  3.857e-10
  10  3.75e+01  3.905e-02  2.961e-02      1  3.466e-10
  11  4.10e+01  3.554e-02  2.390e-02      1  2.640e-10
  12  4.45e+01  3.225e-02  2.117e-02      1  4.761e-10
  13  4.84e+01  2.893e-02  1.961e-02      1  1.834e-10
  14  5.15e+01  2.672e-02  1.795e-02      1  1.972e-10
  15  5.53e+01  2.457e-02  1.615e-02      1  3.801e-10
  16  5.83e+01  2.318e-02  1.501e-02      1  1.520e-10
  17  6.19e+01  2.172e-02  1.407e-02      1  1.273e-10
  18  6.52e+01  2.067e-02  1.363e-02      1  1.397e-10
  19  6.89e+01  1.981e-02  1.360e-02      1  1.330e-10

Set up the problem using SVMBIRExtendedLoss with positivity=True. This loss enforces non-negativity itself, so NonNegativeIndicator is not needed.

[10]:
f_extloss = SVMBIRExtendedLoss(
    y=y,
    A=A,
    W=Diagonal(weights),
    scale=0.5,
    positivity=True,
    prox_kwargs={"maxiter": 5, "ctol": 0.0},
)

solver_extloss = ADMM(
    f=None,
    g_list=[f_extloss, g0],
    C_list=[Identity(x_mrf.shape), Identity(x_mrf.shape)],
    rho_list=[ρ, ρ],
    x0=x0,
    maxiter=20,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True, "period": 5},
)
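Folding the constraint into the loss changes the split from three functionals to two. The prox of the data term must then handle \(x \geq 0\) itself. The toy below assumes a separable case with A = I and diagonal W. For a strictly convex 1-D quadratic, the constrained minimizer is the unconstrained one clipped at 0, and a brute-force grid search confirms this. For a general A the problem is not separable, so SVMBIR has to compute this prox internally. The helper names are hypothetical.

```python
import numpy as np


def prox_weighted_loss(v, t, y, w):
    """prox of t * f with f(x) = 0.5 * sum_i w_i (y_i - x_i)^2 (A = I, separable)."""
    return (v + t * w * y) / (1.0 + t * w)


def prox_weighted_loss_nonneg(v, t, y, w):
    """prox of t * f + indicator(x >= 0): for this separable, strictly convex f,
    each 1-D constrained minimizer is the unconstrained one clipped at 0."""
    return np.maximum(prox_weighted_loss(v, t, y, w), 0.0)


rng = np.random.default_rng(0)
v = rng.normal(size=5)
y = rng.normal(size=5)
w = np.exp(-np.abs(rng.normal(size=5)))  # positive weights
t = 0.1  # e.g. 1 / rho with rho = 10

x_closed = prox_weighted_loss_nonneg(v, t, y, w)

# Brute-force check of each 1-D constrained problem on a fine grid over [0, 10].
grid = np.linspace(0.0, 10.0, 400_001)
obj = 0.5 * (grid[:, None] - v) ** 2 + t * 0.5 * w * (y - grid[:, None]) ** 2
x_brute = grid[np.argmin(obj, axis=0)]
print(np.max(np.abs(x_closed - x_brute)))
```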

Run the ADMM solver.

[11]:
print()
x_extloss = solver_extloss.solve()
hist_extloss = solver_extloss.itstat_object.history(transpose=True)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  3.01e+00  4.720e-01  4.720e-01      0  0.000e+00
   1  6.68e+00  3.471e-01  2.464e-01      1  5.184e-09
   2  9.61e+00  1.689e-01  2.116e-01      1  9.142e-10
   3  1.36e+01  1.023e-01  1.473e-01      1  4.484e-10
   4  1.73e+01  8.577e-02  8.309e-02      1  2.231e-10
   5  2.10e+01  6.649e-02  4.819e-02      1  1.205e-10
   6  2.46e+01  4.644e-02  3.784e-02      1  6.679e-10
   7  2.74e+01  3.438e-02  3.069e-02      1  8.120e-11
   8  3.09e+01  2.915e-02  2.295e-02      1  3.865e-10
   9  3.37e+01  2.598e-02  1.785e-02      1  4.350e-11
  10  3.71e+01  2.280e-02  1.480e-02      1  2.567e-10
  11  4.03e+01  2.039e-02  1.298e-02      1  2.823e-11
  12  4.36e+01  1.890e-02  1.154e-02      1  1.974e-10
  13  4.71e+01  1.743e-02  1.021e-02      1  1.803e-10
  14  5.01e+01  1.644e-02  9.463e-03      1  1.599e-10
  15  5.33e+01  1.556e-02  8.859e-03      1  1.540e-10
  16  5.61e+01  1.488e-02  8.616e-03      1  1.457e-10
  17  5.93e+01  1.425e-02  8.511e-03      1  1.452e-10
  18  6.21e+01  1.373e-02  8.126e-03      1  1.990e-11
  19  6.53e+01  1.316e-02  7.722e-03      1  1.384e-10

Show the recovered images.

[12]:
norm = plot.matplotlib.colors.Normalize(vmin=-0.1 * density, vmax=1.2 * density)
fig, ax = plt.subplots(2, 2, figsize=(15, 15))
plot.imview(img=x_gt, title="Ground Truth Image", cbar=True, fig=fig, ax=ax[0, 0], norm=norm)
plot.imview(
    img=x_mrf,
    title=f"MRF (PSNR: {metric.psnr(x_gt, x_mrf):.2f} dB)",
    cbar=True,
    fig=fig,
    ax=ax[0, 1],
    norm=norm,
)
plot.imview(
    img=x_l2loss,
    title=f"SquaredL2Loss + non-negativity (PSNR: {metric.psnr(x_gt, x_l2loss):.2f} dB)",
    cbar=True,
    fig=fig,
    ax=ax[1, 0],
    norm=norm,
)
plot.imview(
    img=x_extloss,
    title=f"ExtendedLoss (PSNR: {metric.psnr(x_gt, x_extloss):.2f} dB)",
    cbar=True,
    fig=fig,
    ax=ax[1, 1],
    norm=norm,
)
fig.show()
[Figure: Ground Truth Image, MRF, SquaredL2Loss + non-negativity, and ExtendedLoss reconstructions, with PSNR in the panel titles]
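The PSNR values in the panel titles compare each reconstruction against x_gt. The usual definition is \(10\log_{10}(\text{range}^2/\text{MSE})\) dB. This sketch assumes the peak value is the dynamic range (max minus min) of the reference. scico.metric.psnr may differ in details such as how the range is chosen. The psnr helper below is illustrative only.

```python
import numpy as np


def psnr(reference, comparison, signal_range=None):
    """Illustrative PSNR in dB; assumes the peak is the reference's dynamic range."""
    if signal_range is None:
        signal_range = np.max(reference) - np.min(reference)
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


ref = np.linspace(0.0, 1.0, 100)
p_db = psnr(ref, ref + 0.1)
print(p_db)  # approximately 20.0
```

A uniform error of 0.1 on a reference spanning [0, 1] gives an MSE of 0.01, hence 20 dB. All panels use the same reference, so their PSNR values are directly comparable.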

Plot convergence statistics.

[13]:
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
plot.plot(
    snp.vstack((hist_l2loss.Prml_Rsdl, hist_l2loss.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals (SquaredL2Loss + non-negativity)",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[0],
)
ax[0].set_ylim([5e-3, 5e0])
ax[0].xaxis.set_major_locator(MaxNLocator(integer=True))
plot.plot(
    snp.vstack((hist_extloss.Prml_Rsdl, hist_extloss.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals (ExtendedLoss)",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
    fig=fig,
    ax=ax[1],
)
ax[1].set_ylim([5e-3, 5e0])
ax[1].xaxis.set_major_locator(MaxNLocator(integer=True))
fig.show()
[Figure: Residuals (SquaredL2Loss + non-negativity) and Residuals (ExtendedLoss), primal and dual residuals vs. iteration]