3D TV-Regularized Sparse-View CT Reconstruction

This example demonstrates the solution of a sparse-view, 3D CT reconstruction problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]

where \(A\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(C\) is a 3D finite difference operator, and \(\mathbf{x}\) is the desired image.

[1]:
import numpy as np

from mpl_toolkits.axes_grid1 import make_axes_locatable

import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
# Configure scico's plotting helpers for use inside a notebook
# (presumably enables inline figure display; see the scico.plot docs).
plot.config_notebook_plotting()

Create a ground truth image and projector.

[2]:
# Volume dimensions handed to the phantom generator.
Nx = 128
Ny = 256
Nz = 64

# Ground truth test volume.
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))

n_projection = 10  # number of projections
# Evenly spaced projection angles over the half-open interval [0, π).
# endpoint=False keeps π itself out of the set: for a parallel-beam
# geometry (assumed here; confirm against the XRayTransform documentation)
# the view at π is the mirror image of the view at 0, so including both
# endpoints would spend one of the few available views on redundant data.
angles = np.linspace(0, np.pi, n_projection, endpoint=False)
# The two list arguments are presumably the detector spacing and the
# detector (rows, columns) counts; verify against the XRayTransform signature.
A = XRayTransform(tangle.shape, [1.0, 1.0], [Nz, max(Nx, Ny)], angles)  # CT projection operator
y = A @ tangle  # sinogram

Set up ADMM solver object.

[3]:
# Problem and algorithm parameters.
λ = 2e0  # weight on the ℓ2,1 norm (TV) regularization term
ρ = 5e0  # penalty parameter of the ADMM splitting
maxiter = 25  # ADMM iteration count
cg_tol = 1e-4  # relative tolerance for the inner CG solves
cg_maxiter = 25  # cap on CG iterations within each ADMM iteration

# Data fidelity term: squared ℓ2 misfit between A x and the sinogram y.
f = loss.SquaredL2Loss(y=y, A=A)

# Regularization term g(Cx). With append=0 the finite differences taken
# along each axis all have the same shape, which the L21Norm requires;
# combining C with the ℓ2,1 norm makes g(Cx) the isotropic TV of x.
C = linop.FiniteDifference(input_shape=tangle.shape, append=0)
g = λ * functional.L21Norm()

# Initial estimate: the transpose of the projector applied to the sinogram.
x0 = A.T(y)

# Options for the CG-based linear subproblem solver and for the
# iteration statistics display.
cg_kwargs = {"tol": cg_tol, "maxiter": cg_maxiter}
itstat_options = {"display": True, "period": 5}

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs=cg_kwargs),
    itstat_options=itstat_options,
)

Run the solver.

[4]:
# Report the compute device, then run the ADMM iterations.
print(f"Solving on {device_info()}\n")
solver.solve()
# Iteration statistics history and the final solution estimate.
hist = solver.itstat_object.history(transpose=True)
tangle_recon = solver.x

# Summarize reconstruction quality relative to the ground truth volume.
print(
    "TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
    % (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  4.30e+00  1.559e+08  1.098e+03  4.701e+05     25  3.991e-03
   5  9.62e+00  1.861e+05  1.294e+02  5.987e+02     25  1.953e-04
  10  1.34e+01  1.657e+05  4.562e+01  1.051e+02     15  9.501e-05
  15  1.60e+01  1.641e+05  3.663e+01  5.942e+01     19  9.018e-05
  20  1.83e+01  1.632e+05  2.240e+01  3.613e+01      7  9.068e-05
  24  1.93e+01  1.631e+05  1.294e+01  8.363e+00      3  8.906e-05
TV Restruction
SNR: 20.20 (dB), MAE: 0.017

Show the recovered image.

[5]:
# Index of the middle slice along the leading axis of the volume. It is
# derived from the array itself so the "central slice" titles remain
# accurate if the volume dimensions are changed (the previously hard-coded
# index 32 corresponds to Nz // 2 for Nz = 64, assuming the leading axis
# has length Nz).
central = tangle.shape[0] // 2

# Reconstruction quality metrics with respect to the ground truth volume.
recon_snr = metric.snr(tangle, tangle_recon)
recon_mae = metric.mae(tangle, tangle_recon)

fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 5))

# Left panel: ground truth slice.
plot.imview(
    tangle[central], title="Ground truth (central slice)", cbar=None, fig=fig, ax=ax[0]
)

# Right panel: reconstructed slice, with the metrics in the title.
plot.imview(
    tangle_recon[central],
    title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f"
    % (recon_snr, recon_mae),
    fig=fig,
    ax=ax[1],
)

# Attach a colorbar in a dedicated axes to the right of the second panel.
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
../_images/examples_ct_astra_3d_tv_admm_9_0.png