3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)

This example demonstrates the solution of a sparse-view 3D CT reconstruction problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]

where \(C\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(D\) is a 3D finite difference operator, and \(\mathbf{x}\) is the reconstructed image.
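To make the regularization term concrete, the following is a minimal NumPy sketch of the isotropic TV \(\| D \mathbf{x} \|_{2,1}\): the finite differences along each axis are stacked, the \(\ell_2\) norm is taken across the three difference components at each voxel, and the results are summed. The helper name `isotropic_tv` is hypothetical, and the boundary handling (a zero difference at the far end of each axis) is an assumption made for illustration.

```python
import numpy as np


def isotropic_tv(x):
    """Sum over voxels of the l2 norm of the forward-difference gradient."""
    grads = []
    for axis in range(x.ndim):
        # Repeat the last slice so the final difference along this axis is zero
        # and every gradient component has the same shape as x.
        last = np.take(x, [-1], axis=axis)
        grads.append(np.diff(x, axis=axis, append=last))
    g = np.stack(grads)  # shape (ndim, *x.shape)
    # l2 norm across the gradient components, then l1 norm (sum) over voxels
    return float(np.sum(np.sqrt(np.sum(g**2, axis=0))))


# A cube of ones inside a zero volume: only voxels next to a face contribute.
vol = np.zeros((8, 8, 8))
vol[2:6, 2:6, 2:6] = 1.0
print(isotropic_tv(vol))
print(isotropic_tv(np.ones((8, 8, 8))))  # a constant volume has zero TV
```

Because the \(\ell_2\) norm couples the three difference directions at each voxel, the penalty does not favor edges aligned with the coordinate axes as strongly as the anisotropic (plain \(\ell_1\)) variant does.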

In this example the problem is solved via proximal ADMM, while standard ADMM is used in a companion example.

[1]:
import numpy as np

import komplot as kplt
from mpl_toolkits.axes_grid1 import make_axes_locatable

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.examples import create_tangle_phantom
from scico.linop.xray.astra import XRayTransform3D, angle_to_vector
from scico.optimize import ProximalADMM
from scico.util import device_info
kplt.config_notebook_plotting()

Create a ground truth image and projector.

[2]:
Nx = 128
Ny = 256
Nz = 64

tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10  # number of projections
angles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles
det_spacing = [1.0, 1.0]
det_count = [Nz, max(Nx, Ny)]
vectors = angle_to_vector(det_spacing, angles)

# It would have been more straightforward to use the det_spacing and angles keywords
# in this case (since vectors is just computed directly from these two quantities), but
# the more general form is used here as a demonstration.
C = XRayTransform3D(tangle.shape, det_count=det_count, vectors=vectors)  # CT projection operator
y = C @ tangle  # sinogram
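As a rough check of how underdetermined this setup is, the arithmetic below counts unknowns and measurements using the values from the cell above. It assumes that each view contributes one detector image of size `det_count`, which is an assumption about the sinogram layout rather than something stated in this example.

```python
import numpy as np

Nx, Ny, Nz = 128, 256, 64
n_projection = 10
det_count = [Nz, max(Nx, Ny)]  # detector rows, detector columns

angles = np.linspace(0, np.pi, n_projection, endpoint=False)
angle_step_deg = np.degrees(angles[1] - angles[0])  # spacing between views

n_unknowns = Nx * Ny * Nz  # voxels in the volume
n_measurements = n_projection * det_count[0] * det_count[1]  # detector samples

print(angle_step_deg)
print(n_unknowns, n_measurements)  # 2097152 163840
print(n_unknowns / n_measurements)  # 12.8
```

With about 12.8 unknowns per measurement, the data alone cannot determine the volume, which is why the TV regularizer is needed.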

Set up the problem and solver. The problem to be solved is

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]

where \(C\) is the X-ray transform and \(D\) is a finite difference operator. This problem can be expressed as

\[\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 + \lambda \| \mathbf{z}_1 \|_{2,1} \;\; \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; \mathbf{z}_1 = D \mathbf{x} \;,\]

which can be written in the form of a standard ADMM problem

\[\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}\]

with

\[f = 0 \qquad g = g_0 + g_1\]
\[g_0(\mathbf{z}_0) = (1/2) \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \qquad g_1(\mathbf{z}_1) = \lambda \| \mathbf{z}_1 \|_{2,1}\]
\[\begin{split}A = \left( \begin{array}{c} C \\ D \end{array} \right) \qquad B = \left( \begin{array}{cc} -I & 0 \\ 0 & -I \end{array} \right) \qquad \mathbf{c} = \left( \begin{array}{c} 0 \\ 0 \end{array} \right) \;.\end{split}\]

This is a more complex splitting than that used in the companion example, but it allows the use of a proximal ADMM solver in a way that avoids the need for the conjugate gradient sub-iterations used by the ADMM solver in the companion example.
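To illustrate why this splitting needs no inner linear solve, the following is a minimal NumPy sketch of a linearized (proximal) ADMM iteration applied to a 1D analogue of the problem. It is not SCICO's implementation: the identity matrix stands in for \(C\), the function names are hypothetical, \(B = -I\) with \(\nu = 1\), and \(\mu\) is set slightly above \(\| A \|_2^2\) computed from a dense matrix norm. Each iteration uses only products with \(A\) and \(A^T\) together with closed-form proximal maps.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal map of t * ||.||_1."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def linearized_admm_tv1d(y, lam, rho=1.0, maxiter=500):
    """Sketch of linearized ADMM for min_x 0.5 ||y - C x||^2 + lam ||D x||_1."""
    n = y.size
    C = np.eye(n)  # hypothetical stand-in for the X-ray transform
    D = np.diff(np.eye(n), axis=0)  # (n - 1) x n forward differences
    A = np.vstack((C, D))
    mu = 1.01 * np.linalg.norm(A, 2) ** 2  # mu slightly above ||A||_2^2
    x = np.zeros(n)
    u = np.zeros(A.shape[0])
    u_old = u.copy()
    for _ in range(maxiter):
        # x step: f = 0, so its proximal map is the identity and the update
        # is a single gradient-like step involving only A^T.
        x = x - (1.0 / mu) * (A.T @ (2.0 * u - u_old))
        # z step: with B = -I and nu = 1 the argument reduces to A x + u,
        # and g separates into a quadratic part and an l1 part.
        v = A @ x + u
        z0 = (v[:n] + y / rho) / (1.0 + 1.0 / rho)  # prox of 0.5 ||y - .||^2
        z1 = soft_threshold(v[n:], lam / rho)  # prox of lam ||.||_1
        z = np.concatenate((z0, z1))
        # scaled dual update
        u_old = u
        u = u + A @ x - z
    return x


rng = np.random.default_rng(0)
truth = np.repeat([0.0, 1.0, 0.3], 20)  # piecewise-constant test signal
y = truth + 0.1 * rng.standard_normal(truth.size)
x_hat = linearized_admm_tv1d(y, lam=0.5)


def objective(v, lam=0.5):
    return 0.5 * np.sum((y - v) ** 2) + lam * np.sum(np.abs(np.diff(v)))


print(objective(np.zeros_like(y)), objective(x_hat))
```

A standard ADMM \(\mathbf{x}\) step for the same problem would instead require solving a linear system involving \(A^T A\), which for a large X-ray transform is typically done iteratively.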

[3]:
𝛼 = 1e2  # improve problem conditioning by balancing C and D components of A
λ = 2e0  # ℓ2,1 norm regularization parameter
ρ = 5e-3  # ADMM penalty parameter
maxiter = 1000  # number of ADMM iterations

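f = functional.ZeroFunctional()  # f = 0
g0 = loss.SquaredL2Loss(y=y)  # data fidelity term acting on z_0
g1 = (λ / 𝛼) * functional.L21Norm()  # TV term acting on z_1; 1/𝛼 compensates for the 𝛼 scaling of D
g = functional.SeparableFunctional((g0, g1))
D = linop.FiniteDifference(input_shape=tangle.shape, append=0)

A = linop.VerticalStack((C, 𝛼 * D))  # stacked operator with rescaled D
mu, nu = ProximalADMM.estimate_parameters(A)  # estimate proximal ADMM parameters from A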

solver = ProximalADMM(
    f=f,
    g=g,
    A=A,
    B=None,
    rho=ρ,
    mu=mu,
    nu=nu,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 50},
)
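The cell above stacks \(\alpha D\) rather than \(D\) into \(A\) and weights the norm by \(\lambda / \alpha\). Since a norm is positively homogeneous, \((\lambda / \alpha) \| \alpha D \mathbf{x} \|_{2,1} = \lambda \| D \mathbf{x} \|_{2,1}\), so the objective is unchanged while the relative scale of the two blocks of \(A\) is adjusted. The NumPy sketch below checks this identity on a small 1D analogue (where the \(\ell_{2,1}\) norm reduces to the \(\ell_1\) norm) and shows how \(\| A \|_2^2\) can be estimated by power iteration. The random matrix standing in for \(C\) and the function name are hypothetical, and the power iteration is only an illustration of one way to obtain such an estimate, not a description of the internals of `ProximalADMM.estimate_parameters`.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 16
C = rng.standard_normal((12, n))  # hypothetical stand-in for the X-ray transform
D = np.diff(np.eye(n), axis=0)  # 1D forward differences
x = rng.standard_normal(n)
lam, alpha = 2.0, 1e2

# Scaling D by alpha and the norm weight by 1 / alpha leaves the term unchanged.
unscaled = lam * np.sum(np.abs(D @ x))
scaled = (lam / alpha) * np.sum(np.abs((alpha * D) @ x))


def power_iteration_sq_norm(A, n_iter=500, seed=0):
    """Estimate ||A||_2^2, the largest eigenvalue of A^T A, by power iteration."""
    v = np.random.default_rng(seed).standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        w = A.T @ (A @ v)
        v = w / np.linalg.norm(w)
    # Rayleigh quotient of A^T A at the (unit norm) iterate
    return float(v @ (A.T @ (A @ v)))


A = np.vstack((C, alpha * D))
est = power_iteration_sq_norm(A)
exact = np.linalg.norm(A, 2) ** 2  # dense reference, feasible only for small A

print(unscaled, scaled)
print(est, exact)
```

The Rayleigh quotient never exceeds the true value of \(\| A \|_2^2\), so an estimate obtained this way is usually multiplied by a small safety factor before it is used as a bound for \(\mu\).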

Run the solver.

[4]:
print(f"Solving on {device_info()}\n")
tangle_recon = solver.solve()

print(
    "TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl
-----------------------------------------------
   0  1.81e+00  1.121e+04  2.964e+04  2.964e+04
  50  1.02e+01  7.988e+05  1.818e+04  6.467e+02
 100  1.70e+01  6.934e+05  1.024e+04  5.291e+02
 150  2.38e+01  4.895e+05  5.486e+03  4.237e+02
 200  3.04e+01  5.058e+05  2.960e+03  3.089e+02
 250  3.73e+01  4.514e+05  2.242e+03  2.426e+02
 300  4.44e+01  4.201e+05  2.234e+03  1.673e+02
 350  5.14e+01  3.891e+05  1.946e+03  1.372e+02
 400  5.83e+01  3.742e+05  1.719e+03  1.019e+02
 450  6.50e+01  3.709e+05  1.127e+03  9.809e+01
 500  7.21e+01  3.699e+05  8.707e+02  7.260e+01
 550  7.88e+01  3.613e+05  6.581e+02  5.549e+01
 600  8.55e+01  3.579e+05  4.328e+02  4.778e+01
 650  9.23e+01  3.572e+05  3.813e+02  4.176e+01
 700  9.90e+01  3.572e+05  3.522e+02  3.024e+01
 750  1.06e+02  3.571e+05  2.972e+02  2.586e+01
 800  1.13e+02  3.529e+05  2.519e+02  2.231e+01
 850  1.19e+02  3.554e+05  1.873e+02  1.808e+01
 900  1.26e+02  3.556e+05  1.464e+02  1.466e+01
 950  1.32e+02  3.537e+05  1.069e+02  1.183e+01
 999  1.39e+02  3.544e+05  7.827e+01  9.771e+00
TV Reconstruction
SNR: 14.31 (dB), MAE: 0.048
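For readers unfamiliar with the two reported quality measures, the sketch below computes a mean absolute error and a signal-to-noise ratio in dB with plain NumPy. The function names are hypothetical, and the SNR normalization used here (variance of the reference divided by the mean squared error) is an assumption: other definitions use the mean square of the reference instead, so consult the `scico.metric` documentation for the exact convention behind the numbers above.

```python
import numpy as np


def mae(reference, comparison):
    """Mean absolute error between two arrays."""
    return float(np.mean(np.abs(reference - comparison)))


def snr_db(reference, comparison):
    """SNR in dB, assuming reference variance over mean squared error."""
    signal_power = np.var(reference)
    error_power = np.mean(np.abs(reference - comparison) ** 2)
    return float(10.0 * np.log10(signal_power / error_power))


reference = np.array([0.0, 1.0, 0.0, 1.0])
comparison = reference + 0.1  # constant error of 0.1 in every sample

print(mae(reference, comparison))
print(snr_db(reference, comparison))
```

Under this definition a higher SNR and a lower MAE both indicate a reconstruction closer to the ground truth.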

Show the recovered volume.

[5]:
fig, ax = kplt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(7, 6))
kplt.imview(
    tangle[32],
    title="Ground truth",
    cmap=kplt.cm.viridis,
    show_cbar=None,
    ax=ax[0],
)
kplt.imview(
    tangle_recon[32],
    title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
    cmap=kplt.cm.viridis,
    ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.suptitle("Central slice on $z$ axis (axis 0)")
fig.tight_layout()
fig.show()

fig, ax = kplt.subplots(
    nrows=1,
    ncols=2,
    sharex=True,
    sharey=True,
    gridspec_kw={"width_ratios": [1, 1.08]},
    figsize=(13, 4),
)
kplt.imview(
    tangle[:, 128],
    title="Ground truth",
    cmap=kplt.cm.viridis,
    ax=ax[0],
)
kplt.imview(
    tangle_recon[:, 128],
    title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
    cmap=kplt.cm.viridis,
    ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], ax=ax[1], cax=cax, label="arbitrary units")
fig.suptitle("Central slice on $y$ axis (axis 1)")
fig.tight_layout()
fig.show()
[Figure: ground truth and TV reconstruction, central slice on z axis (axis 0)]
[Figure: ground truth and TV reconstruction, central slice on y axis (axis 1)]