3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)¶
This example demonstrates solution of a sparse-view, 3D CT reconstruction problem with isotropic total variation (TV) regularization
where \(C\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(D\) is a 3D finite difference operator, and \(\mathbf{x}\) is the reconstructed image.
In this example the problem is solved via proximal ADMM, while standard ADMM is used in a companion example.
[1]:
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.examples import create_tangle_phantom
from scico.linop.xray.astra import XRayTransform3D, angle_to_vector
from scico.optimize import ProximalADMM
from scico.util import device_info
plot.config_notebook_plotting()
Create a ground truth image and projector.
[2]:
Nx = 128
Ny = 256
Nz = 64
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))
n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_spacing = [1.0, 1.0]
det_count = [Nz, max(Nx, Ny)]
vectors = angle_to_vector(det_spacing, angles)
# It would have been more straightforward to use the det_spacing and angles keywords
# in this case (since vectors is just computed directly from these two quantities), but
# the more general form is used here as a demonstration.
C = XRayTransform3D(tangle.shape, det_count=det_count, vectors=vectors) # CT projection operator
y = C @ tangle # sinogram
Set up problem and solver. We want to minimize the functional
where \(C\) is the X-ray transform and \(D\) is a finite difference operator. This problem can be expressed as
which can be written in the form of a standard ADMM problem
with
This is a more complex splitting than that used in the companion example, but it allows the use of a proximal ADMM solver in a way that avoids the need for the conjugate gradient sub-iterations used by the ADMM solver in the companion example.
[3]:
𝛼 = 1e2 # improve problem conditioning by balancing C and D components of A
λ = 2e0 / 𝛼 # ℓ2,1 norm regularization parameter
ρ = 5e-3 # ADMM penalty parameter
maxiter = 1000 # number of ADMM iterations
f = functional.ZeroFunctional()
g0 = loss.SquaredL2Loss(y=y)
g1 = λ * functional.L21Norm()
g = functional.SeparableFunctional((g0, g1))
D = linop.FiniteDifference(input_shape=tangle.shape, append=0)
A = linop.VerticalStack((C, 𝛼 * D))
mu, nu = ProximalADMM.estimate_parameters(A)
solver = ProximalADMM(
f=f,
g=g,
A=A,
B=None,
rho=ρ,
mu=mu,
nu=nu,
maxiter=maxiter,
itstat_options={"display": True, "period": 50},
)
Run the solver.
[4]:
print(f"Solving on {device_info()}\n")
tangle_recon = solver.solve()
print(
"TV Restruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl
-----------------------------------------------
0 1.19e+00 1.361e+04 3.267e+04 3.267e+04
50 4.39e+00 4.273e+05 2.009e+04 4.357e+02
100 7.54e+00 4.010e+05 1.103e+04 4.197e+02
150 1.68e+01 2.925e+05 5.500e+03 3.164e+02
200 2.33e+01 2.753e+05 2.533e+03 2.388e+02
250 2.87e+01 2.249e+05 1.595e+03 1.623e+02
300 3.55e+01 2.052e+05 1.648e+03 1.128e+02
350 3.90e+01 1.892e+05 1.632e+03 7.943e+01
400 4.25e+01 1.815e+05 1.385e+03 6.610e+01
450 4.93e+01 1.719e+05 9.755e+02 6.438e+01
500 5.49e+01 1.728e+05 6.477e+02 4.537e+01
550 5.88e+01 1.693e+05 4.373e+02 3.503e+01
600 6.28e+01 1.682e+05 2.703e+02 2.917e+01
650 6.60e+01 1.677e+05 1.962e+02 2.534e+01
700 6.99e+01 1.664e+05 1.915e+02 1.911e+01
750 7.42e+01 1.665e+05 1.896e+02 1.782e+01
800 7.85e+01 1.655e+05 1.776e+02 1.507e+01
850 8.25e+01 1.658e+05 1.504e+02 1.204e+01
900 8.64e+01 1.655e+05 1.167e+02 1.160e+01
950 9.02e+01 1.654e+05 8.773e+01 9.069e+00
999 9.37e+01 1.655e+05 6.540e+01 8.168e+00
TV Restruction
SNR: 22.04 (dB), MAE: 0.009
Show the recovered image.
[5]:
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(7, 6))
plot.imview(
tangle[32],
title="Ground truth (central slice)",
cmap=plot.cm.Blues,
cbar=None,
fig=fig,
ax=ax[0],
)
plot.imview(
tangle_recon[32],
title="TV Reconstruction (central slice)\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
cmap=plot.cm.Blues,
fig=fig,
ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.show()