TV-Regularized Sparse-View CT Reconstruction (Integrated Projector)

This example demonstrates the solution of a sparse-view CT reconstruction problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]

where \(A\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the reconstructed image. This example uses the CT projector integrated into scico, while the companion example script uses the projector provided by the astra package.
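
To make the regularization term concrete, the following minimal sketch evaluates \(\| C \mathbf{x} \|_{2,1}\) on tiny images using plain Python lists rather than scico. The helper name `isotropic_tv` is hypothetical, and the zero difference at the last row and column is an assumption intended to mirror a boundary handling in which both difference arrays keep the image shape.

```python
import math


def isotropic_tv(img):
    """Sum over pixels of the Euclidean norm of the forward-difference pair."""
    rows, cols = len(img), len(img[0])
    total = 0.0
    for i in range(rows):
        for j in range(cols):
            # Forward differences; the last row/column gets a zero difference
            # so that both difference arrays have the same shape as the image.
            dv = img[i + 1][j] - img[i][j] if i + 1 < rows else 0.0
            dh = img[i][j + 1] - img[i][j] if j + 1 < cols else 0.0
            total += math.hypot(dv, dh)
    return total


flat = [[1.0] * 4 for _ in range(4)]                 # constant image
step = [[0.0, 0.0, 1.0, 1.0] for _ in range(4)]      # one vertical edge
corner = [[0.0, 1.0], [1.0, 1.0]]                    # both differences active at one pixel

print(isotropic_tv(flat))    # no variation
print(isotropic_tv(step))    # one unit jump in each of the 4 rows
print(isotropic_tv(corner))  # sqrt(2) rather than 2: the pair is combined with a 2-norm
```

The `corner` case shows what "isotropic" means here: the two differences at a pixel are combined with a Euclidean norm before summing, whereas an anisotropic (plain \(\ell_1\)) penalty would add their absolute values.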

[1]:
import numpy as np

import komplot as kplt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.linop.xray import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

kplt.config_notebook_plotting()

Create a ground truth image.

[2]:
N = 512  # phantom size
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))

Configure CT projection operator and generate synthetic measurements.

[3]:
n_projection = 45  # number of projections
angles = np.linspace(0, np.pi, n_projection, endpoint=False)  # evenly spaced projection angles
det_count = int(N * 1.05 / np.sqrt(2.0))
dx = 1.0 / np.sqrt(2)
A = XRayTransform2D(
    (N, N), angles + np.pi / 2.0, det_count=det_count, dx=dx
)  # CT projection operator
y = A @ x_gt  # sinogram
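
The arithmetic behind these geometry parameters can be checked with the standard library alone. The sketch below reproduces the angle grid and detector count from the cell above. It assumes that `dx` is the image pixel side length expressed in units of the detector pixel width; the variable `image_side` is introduced here only for illustration.

```python
import math

N = 512
n_projection = 45

# Same values as np.linspace(0, np.pi, n_projection, endpoint=False)
angles = [k * math.pi / n_projection for k in range(n_projection)]
spacing_deg = math.degrees(angles[1] - angles[0])

det_count = int(N * 1.05 / math.sqrt(2.0))
dx = 1.0 / math.sqrt(2)
# Image side length in detector-pixel units (assumed meaning of dx)
image_side = N * dx

print(f"angular spacing: {spacing_deg:.1f} degrees")
print(f"detector pixels: {det_count}")
print(f"image side in detector units: {image_side:.2f}")
```

Under that assumption, the factor 1.05 makes the detector roughly 5% wider than the image side, and 45 views over a half rotation is what makes the problem sparse-view.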

Set up problem functional and ADMM solver object.
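
For reference, the textbook scaled-form ADMM iterations for this problem are given below. The auxiliary variable \(\mathbf{z}\) (standing in for \(C \mathbf{x}\)) and the scaled dual variable \(\mathbf{u}\) are notation introduced here, and the library's internal implementation may organize the steps differently.

\[\begin{aligned}
\mathbf{x}^{(k+1)} &= \mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + (\rho/2) \| C \mathbf{x} - \mathbf{z}^{(k)} + \mathbf{u}^{(k)} \|_2^2 \\
\mathbf{z}^{(k+1)} &= \mathrm{argmin}_{\mathbf{z}} \; \lambda \| \mathbf{z} \|_{2,1} + (\rho/2) \| C \mathbf{x}^{(k+1)} - \mathbf{z} + \mathbf{u}^{(k)} \|_2^2 \\
\mathbf{u}^{(k+1)} &= \mathbf{u}^{(k)} + C \mathbf{x}^{(k+1)} - \mathbf{z}^{(k+1)}
\end{aligned}\]

The \(\mathbf{x}\) step is a quadratic problem whose solution satisfies the linear system

\[(A^T A + \rho C^T C) \, \mathbf{x} = A^T \mathbf{y} + \rho C^T (\mathbf{z}^{(k)} - \mathbf{u}^{(k)}) \;,\]

which is why a conjugate gradient (CG) tolerance and iteration limit appear among the solver parameters.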

[4]:
λ = 2e0  # ℓ2,1 norm regularization parameter
ρ = 5e0  # ADMM penalty parameter
maxiter = 25  # number of ADMM iterations
cg_tol = 1e-4  # CG relative tolerance
cg_maxiter = 25  # maximum CG iterations per ADMM iteration

# The append=0 option makes the results of horizontal and vertical
# finite differences the same shape, which is required for the L21Norm,
# which is used so that g(Cx) corresponds to isotropic TV.
C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
g = λ * functional.L21Norm()
f = loss.SquaredL2Loss(y=y, A=A)
x0 = snp.clip(A.fbp(y), 0, 1.0)

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
    itstat_options={"display": True, "period": 5},
)
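
ADMM handles the TV term through an auxiliary variable that stands in for \(C \mathbf{x}\), and updating that variable amounts to applying the proximal operator of the scaled \(\ell_{2,1}\) norm, with threshold \(\lambda / \rho\) in the standard scaled form. The sketch below illustrates that operation (block soft-thresholding) on flat lists holding one difference pair per pixel. The function name `prox_l21` and the data layout are hypothetical and are not scico's API.

```python
import math


def prox_l21(v_rows, v_cols, threshold):
    """Block soft-thresholding of per-pixel (row, column) difference pairs."""
    out_rows, out_cols = [], []
    for a, b in zip(v_rows, v_cols):
        norm = math.hypot(a, b)
        # Pairs with norm <= threshold are set to zero; longer pairs keep
        # their direction and have their length reduced by the threshold.
        scale = max(0.0, 1.0 - threshold / norm) if norm > 0.0 else 0.0
        out_rows.append(scale * a)
        out_cols.append(scale * b)
    return out_rows, out_cols


lam, rho = 2.0, 5.0       # the values of λ and ρ used in this example
threshold = lam / rho     # 0.4

# Three pixels: a strong edge, a weak fluctuation, and a flat region
rows, cols = prox_l21([3.0, 0.1, 0.0], [4.0, 0.2, 0.0], threshold)
print(rows, cols)
```

The strong edge (norm 5) is only slightly shortened, while the weak fluctuation (norm below 0.4) is removed entirely, which is the mechanism by which TV regularization suppresses streak artifacts while preserving edges.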

Run the solver.

[5]:
print(f"Solving on {device_info()}\n")
solver.solve()
hist = solver.itstat_object.history(transpose=True)
x_reconstruction = snp.clip(solver.x, 0, 1.0)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.16e+00  1.645e+03  1.800e+02  8.801e+02     20  8.980e-05
   5  3.92e+00  2.151e+04  4.772e+01  6.780e+01     15  9.016e-05
  10  4.17e+00  2.688e+04  4.667e+01  3.189e+01      0  9.435e-05
  15  4.33e+00  2.776e+04  2.067e+01  1.424e+01      0  9.881e-05
  20  4.48e+00  2.801e+04  1.491e+01  6.985e+00      0  9.695e-05
  24  4.69e+00  2.758e+04  1.472e+01  2.252e+01     12  8.992e-05

Show the recovered image.

[6]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(15, 5))
kplt.imview(x_gt, title="Ground truth", cmap="Blues", show_cbar=None, ax=ax[0])
kplt.imview(
    x0,
    title="FBP Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(x_gt, x0), metric.mae(x_gt, x0)),
    cmap="Blues",
    show_cbar=None,
    ax=ax[1],
)
kplt.imview(
    x_reconstruction,
    title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(x_gt, x_reconstruction), metric.mae(x_gt, x_reconstruction)),
    cmap="Blues",
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
[Figure: ground truth, FBP reconstruction, and TV reconstruction; the two reconstructions are titled with SNR (dB) and MAE.]
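
The two quality metrics in the panel titles can be illustrated with a short sketch. It uses the usual definitions of signal-to-noise ratio (in dB) and mean absolute error, which are assumed here to match `metric.snr` and `metric.mae`. The helper names `snr_db` and `mae` and the four-sample signals are made up for illustration.

```python
import math


def snr_db(reference, estimate):
    """10 log10 of reference energy over error energy."""
    signal = sum(r * r for r in reference)
    noise = sum((r - e) ** 2 for r, e in zip(reference, estimate))
    return 10.0 * math.log10(signal / noise)


def mae(reference, estimate):
    """Mean of the absolute differences."""
    return sum(abs(r - e) for r, e in zip(reference, estimate)) / len(reference)


reference = [1.0, 0.0, 1.0, 0.0]
estimate = [0.9, 0.1, 1.1, -0.1]  # every sample is off by 0.1

print(f"SNR: {snr_db(reference, estimate):.2f} dB, MAE: {mae(reference, estimate):.3f}")
```

Higher SNR and lower MAE both indicate a reconstruction closer to the ground truth.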

Plot convergence statistics.

[7]:
fig, ax = kplt.subplots(nrows=1, ncols=2, figsize=(12, 5))
kplt.plot(
    hist.Objective,
    title="Objective function",
    xlabel="Iteration",
    ylabel="Functional value",
    ax=ax[0],
)
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
    ax=ax[1],
)
fig.show()
[Figure: objective function value versus iteration (left); primal and dual residuals versus iteration on a log scale (right).]
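
As a guide to reading the residual plot, the standard ADMM definitions of the primal and dual residuals for this problem are as follows, where \(\mathbf{z}\) denotes the ADMM auxiliary variable that stands in for \(C \mathbf{x}\) (notation introduced here). The values reported in the `Prml Rsdl` and `Dual Rsdl` columns are assumed to be the \(\ell_2\) norms of these quantities, possibly up to differences in scaling or sign convention.

\[\mathbf{r}^{(k)} = C \mathbf{x}^{(k)} - \mathbf{z}^{(k)} \;, \qquad \mathbf{s}^{(k)} = \rho \, C^T \left( \mathbf{z}^{(k)} - \mathbf{z}^{(k-1)} \right) \;.\]

The primal residual measures how far the iterates are from satisfying the constraint \(C \mathbf{x} = \mathbf{z}\), and the dual residual measures how much the auxiliary variable is still changing between iterations; both tend to zero as the algorithm converges.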