3D TV-Regularized Sparse-View CT Reconstruction (Proximal ADMM Solver)
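This example demonstrates the solution of a sparse-view, 3D CT reconstruction problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]

where \(C\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(D\) is a 3D finite difference operator, \(\lambda\) is the regularization parameter, and \(\mathbf{x}\) is the reconstructed volume.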
In this example the problem is solved via proximal ADMM, while standard ADMM is used in a companion example.
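For reference, the isotropic TV term \(\| D \mathbf{x} \|_{2,1}\) takes, at each voxel, the Euclidean norm of the vector of finite differences along the three axes and sums these norms over all voxels. The NumPy sketch below computes this quantity together with the proximal operator of the ℓ2,1 norm (block soft thresholding), which is the operation a proximal solver applies to the TV term. It is an illustration only: the helper names are hypothetical, and it assumes forward differences with a zero appended at the end of each axis, by analogy with the `append=0` argument used below; the exact boundary handling of `scico.linop.FiniteDifference` may differ.

```python
import numpy as np


def finite_differences(x):
    """Stack forward differences of x along every axis (hypothetical helper).

    Assumption: a zero is appended at the end of each axis, like np.diff(..., append=0),
    so every difference array has the same shape as x.
    """
    return np.stack([np.diff(x, axis=k, append=0) for k in range(x.ndim)])


def l21_norm(v):
    """Sum over voxels of the l2 norm taken across axis 0 (the difference direction)."""
    return np.sum(np.sqrt(np.sum(v**2, axis=0)))


def isotropic_tv(x):
    """Isotropic TV: l2,1 norm of the stacked finite differences."""
    return l21_norm(finite_differences(x))


def prox_l21(v, t):
    """Block soft thresholding: the prox of t * ||.||_{2,1}, grouping over axis 0."""
    norms = np.sqrt(np.sum(v**2, axis=0, keepdims=True))
    # Shrink each group's norm by t; groups with norm <= t are set to zero.
    scale = np.maximum(1.0 - t / np.maximum(norms, 1e-12), 0.0)
    return scale * v


# Usage: TV of a small cube embedded in a zero volume.
vol = np.zeros((4, 5, 6))
vol[1:3, 1:4, 2:5] = 1.0
print(isotropic_tv(vol))
```

Because each voxel's differences are grouped before thresholding, the prox shrinks the gradient magnitude while keeping its direction, which is what distinguishes isotropic from anisotropic TV.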
import numpy as np
import komplot as kplt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.examples import create_tangle_phantom
from scico.linop.xray.astra import XRayTransform3D, angle_to_vector
from scico.optimize import ProximalADMM
from scico.util import device_info
kplt.config_notebook_plotting()
Create a ground truth volume and a CT projector.
Nx = 128
Ny = 256
Nz = 64
tangle = snp.array(create_tangle_phantom(Nx, Ny, Nz))
n_projection = 10 # number of projections
angles = np.linspace(0, np.pi, n_projection, endpoint=False) # evenly spaced projection angles
det_spacing = [1.0, 1.0]
det_count = [Nz, max(Nx, Ny)]
vectors = angle_to_vector(det_spacing, angles)
# It would have been more straightforward to use the det_spacing and angles keywords
# in this case (since vectors is just computed directly from these two quantities), but
# the more general form is used here as a demonstration.
C = XRayTransform3D(tangle.shape, det_count=det_count, vectors=vectors) # CT projection operator
y = C @ tangle # sinogram
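To make the sparse-view setting concrete, the short calculation below compares the number of unknowns with the number of sinogram entries for the sizes used above, assuming one measurement per detector pixel per projection angle. It uses plain NumPy and does not call ASTRA.

```python
import numpy as np

Nx, Ny, Nz = 128, 256, 64
n_projection = 10
det_count = [Nz, max(Nx, Ny)]

angles = np.linspace(0, np.pi, n_projection, endpoint=False)
step_deg = np.degrees(angles[1] - angles[0])  # angular spacing between views

n_unknowns = Nx * Ny * Nz  # voxels in the reconstructed volume
n_measurements = n_projection * det_count[0] * det_count[1]  # sinogram entries
ratio = n_unknowns / n_measurements  # how underdetermined the system is

print(f"angular step: {step_deg:.1f} deg")
print(f"unknowns: {n_unknowns}, measurements: {n_measurements}, ratio: {ratio:.1f}")
```

With these sizes the angular spacing is 18 degrees, and there are 2,097,152 unknowns but only 163,840 measurements, so the linear system is underdetermined by a factor of 12.8. This is why a regularizer such as TV is needed.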
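Set up the problem and solver. We want to minimize the functional

\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - C \mathbf{x} \|_2^2 + \lambda \| D \mathbf{x} \|_{2,1} \;,\]

where \(C\) is the X-ray transform and \(D\) is a finite difference operator. For a scaling constant \(\alpha > 0\), this problem can be expressed as

\[\mathrm{argmin}_{\mathbf{x}, \mathbf{z}_0, \mathbf{z}_1} \; \frac{1}{2} \| \mathbf{y} - \mathbf{z}_0 \|_2^2 + \frac{\lambda}{\alpha} \| \mathbf{z}_1 \|_{2,1} \;\; \text{such that} \;\; \mathbf{z}_0 = C \mathbf{x} \;\; \text{and} \;\; \mathbf{z}_1 = \alpha D \mathbf{x} \;,\]

which can be written in the form of a standard ADMM problem

\[\mathrm{argmin}_{\mathbf{x}, \mathbf{z}} \; f(\mathbf{x}) + g(\mathbf{z}) \;\; \text{such that} \;\; A \mathbf{x} + B \mathbf{z} = \mathbf{c}\]

with

\[f = 0 \;, \quad g(\mathbf{z}) = g_0(\mathbf{z}_0) + g_1(\mathbf{z}_1) \;, \quad g_0(\mathbf{z}_0) = \frac{1}{2} \| \mathbf{y} - \mathbf{z}_0 \|_2^2 \;, \quad g_1(\mathbf{z}_1) = \frac{\lambda}{\alpha} \| \mathbf{z}_1 \|_{2,1} \;,\]

\[A = \begin{pmatrix} C \\ \alpha D \end{pmatrix} \;, \quad B = \begin{pmatrix} -I & 0 \\ 0 & -I \end{pmatrix} \;, \quad \mathbf{c} = \begin{pmatrix} 0 \\ 0 \end{pmatrix} \;, \quad \mathbf{z} = \begin{pmatrix} \mathbf{z}_0 \\ \mathbf{z}_1 \end{pmatrix} \;.\]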
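This splitting is more complex than the one used in the companion example, but it allows a proximal ADMM solver to be used without the conjugate gradient sub-iterations required by the ADMM solver in the companion example. Because the term applied directly to \(\mathbf{x}\) is zero, each proximal ADMM iteration requires only applications of \(C\), \(D\), and their adjoints, together with the proximal operators of the squared ℓ2 loss and the ℓ2,1 norm, both of which have closed forms.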
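The sketch below illustrates why this kind of splitting avoids inner linear solves. It applies a linearized (proximal) ADMM iteration, in the standard form for \(\min_{\mathbf{x}} g(A \mathbf{x})\), to a small 1D analogue of the problem in plain NumPy: the \(\mathbf{x}\)-update is a single explicit step with \(A^T\), and the \(\mathbf{z}\)-update is a closed-form prox for each block. In 1D the ℓ2,1 norm reduces to the ℓ1 norm, so block shrinkage becomes ordinary soft thresholding. The function names, the identity stand-in for \(C\), and the step-size rule are assumptions for illustration; the update ordering and the `mu`/`nu` parameterization of `scico.optimize.ProximalADMM` are not reproduced here and may differ.

```python
import numpy as np


def soft_threshold(v, t):
    """Prox of t * ||.||_1 (elementwise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def linearized_admm_tv(C, D, y, lam, alpha=1.0, rho=1.0, n_iter=2000):
    """Hypothetical helper: minimize 0.5*||y - C x||^2 + lam*||D x||_1 with A = [C; alpha*D]."""
    A = np.vstack([C, alpha * D])
    m = C.shape[0]
    # Step size chosen so that mu * rho * ||A||_2^2 < 1 (standard linearized ADMM condition).
    mu = 0.99 / (rho * np.linalg.norm(A, 2) ** 2)
    x = np.zeros(C.shape[1])
    z = A @ x
    u = np.zeros_like(z)
    for _ in range(n_iter):
        # x-update: f = 0, so its prox is the identity and only A and A^T are applied.
        x = x - mu * rho * (A.T @ (A @ x - z + u))
        v = A @ x + u
        # z0-update: prox of (1/rho) * 0.5 * ||y - z0||^2, in closed form.
        z0 = (y + rho * v[:m]) / (1.0 + rho)
        # z1-update: prox of (1/rho) * (lam / alpha) * ||z1||_1.
        z1 = soft_threshold(v[m:], lam / (alpha * rho))
        z = np.concatenate([z0, z1])
        u = u + A @ x - z  # scaled dual update
    return x


def objective(C, D, y, lam, x):
    return 0.5 * np.sum((y - C @ x) ** 2) + lam * np.sum(np.abs(D @ x))


n = 20
C = np.eye(n)  # assumption: denoising stand-in for the X-ray transform
D = np.diff(np.eye(n), axis=0)  # (n-1) x n forward-difference matrix
y = np.where(np.arange(n) < n // 2, 0.0, 1.0) + 0.1 * (-1.0) ** np.arange(n)
x_hat = linearized_admm_tv(C, D, y, lam=0.5)
print(objective(C, D, y, 0.5, y), objective(C, D, y, 0.5, x_hat))
```

The objective printed for `x_hat` should be lower than that of the noisy data `y` itself. In the full 3D problem, soft thresholding is replaced by block soft thresholding over the three difference directions, and the dense matrices by the `C` and `D` operators.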
𝛼 = 1e2 # improve problem conditioning by balancing C and D components of A
λ = 2e0 # ℓ2,1 norm regularization parameter
ρ = 5e-3 # ADMM penalty parameter
maxiter = 1000 # number of ADMM iterations
f = functional.ZeroFunctional()
g0 = loss.SquaredL2Loss(y=y)
g1 = (λ / 𝛼) * functional.L21Norm()
g = functional.SeparableFunctional((g0, g1))
D = linop.FiniteDifference(input_shape=tangle.shape, append=0)
A = linop.VerticalStack((C, 𝛼 * D))
mu, nu = ProximalADMM.estimate_parameters(A)
solver = ProximalADMM(
f=f,
g=g,
A=A,
B=None,
rho=ρ,
mu=mu,
nu=nu,
maxiter=maxiter,
itstat_options={"display": True, "period": 50},
)
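The `mu` and `nu` values returned by `ProximalADMM.estimate_parameters(A)` act as step-size parameters. As we understand the scico documentation, convergence requires conditions of the form \(\mu > \|A\|_2^2\) and \(\nu > \|B\|_2^2\), with \(\|B\|_2 = 1\) when `B=None` is taken to mean \(B = -I\); treat both points as assumptions and check the `ProximalADMM` docstring. The sketch below estimates \(\|A\|_2^2\) for a small dense stand-in of \(A = [C; \alpha D]\) by power iteration on \(A^T A\) and applies a 1% safety margin. The helpers `squared_spectral_norm` and `estimate_mu_nu` and the `factor` value are hypothetical, not scico's implementation.

```python
import numpy as np


def squared_spectral_norm(A, n_iter=500, seed=0):
    """Estimate ||A||_2^2, the largest eigenvalue of A^T A, by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(A.shape[1])
    v /= np.linalg.norm(v)
    s = 0.0
    for _ in range(n_iter):
        w = A.T @ (A @ v)
        s = np.linalg.norm(w)  # approaches the top eigenvalue of A^T A from below
        v = w / s
    return s


def estimate_mu_nu(A, factor=1.01):
    """Hypothetical helper: scale ||A||^2 and ||B||^2 by a safety factor (assumes B = -I)."""
    return factor * squared_spectral_norm(A), factor * 1.0


alpha = 1e2
C = np.random.default_rng(1).standard_normal((30, 20))  # stand-in for the X-ray transform
D = np.diff(np.eye(20), axis=0)  # 1D forward differences
A = np.vstack([C, alpha * D])
mu, nu = estimate_mu_nu(A)
print(mu, nu)
```

Power iteration only needs products with \(A\) and \(A^T\), which is why it scales to operators such as the X-ray transform that are never formed as explicit matrices.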
Run the solver.
print(f"Solving on {device_info()}\n")
tangle_recon = solver.solve()
print(
"TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon))
)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl
-----------------------------------------------
0 1.81e+00 1.121e+04 2.964e+04 2.964e+04
50 1.02e+01 7.988e+05 1.818e+04 6.467e+02
100 1.70e+01 6.934e+05 1.024e+04 5.291e+02
150 2.38e+01 4.895e+05 5.486e+03 4.237e+02
200 3.04e+01 5.058e+05 2.960e+03 3.089e+02
250 3.73e+01 4.514e+05 2.242e+03 2.426e+02
300 4.44e+01 4.201e+05 2.234e+03 1.673e+02
350 5.14e+01 3.891e+05 1.946e+03 1.372e+02
400 5.83e+01 3.742e+05 1.719e+03 1.019e+02
450 6.50e+01 3.709e+05 1.127e+03 9.809e+01
500 7.21e+01 3.699e+05 8.707e+02 7.260e+01
550 7.88e+01 3.613e+05 6.581e+02 5.549e+01
600 8.55e+01 3.579e+05 4.328e+02 4.778e+01
650 9.23e+01 3.572e+05 3.813e+02 4.176e+01
700 9.90e+01 3.572e+05 3.522e+02 3.024e+01
750 1.06e+02 3.571e+05 2.972e+02 2.586e+01
800 1.13e+02 3.529e+05 2.519e+02 2.231e+01
850 1.19e+02 3.554e+05 1.873e+02 1.808e+01
900 1.26e+02 3.556e+05 1.464e+02 1.466e+01
950 1.32e+02 3.537e+05 1.069e+02 1.183e+01
999 1.39e+02 3.544e+05 7.827e+01 9.771e+00
TV Reconstruction
SNR: 14.31 (dB), MAE: 0.048
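The SNR and MAE figures above compare the reconstruction with the ground truth. The sketch below shows one common pair of definitions: SNR in dB as the ratio of the reference variance to the mean squared error, and MAE as the mean absolute difference. We believe this matches `scico.metric.snr` and `scico.metric.mae`, but that is an assumption; the helper names `snr_db` and `mae` are hypothetical.

```python
import numpy as np


def snr_db(reference, comparison):
    """SNR in dB: variance of the reference divided by the mean squared error."""
    mse = np.mean(np.abs(reference - comparison) ** 2)
    return 10.0 * np.log10(np.var(reference) / mse)


def mae(reference, comparison):
    """Mean absolute error between reference and comparison."""
    return np.mean(np.abs(reference - comparison))


ref = np.array([0.0, 1.0, 0.0, 1.0])
est = ref + np.array([0.1, -0.1, 0.1, -0.1])
print(f"SNR: {snr_db(ref, est):.2f} (dB), MAE: {mae(ref, est):.3f}")
```

For this toy pair the variance of the reference is 0.25 and the mean squared error is 0.01, giving an SNR of 10 log10(25) ≈ 13.98 dB and an MAE of 0.100. Because the SNR is normalized by the reference variance, it does not depend on the overall intensity offset of the ground truth.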
Show central slices of the ground truth and recovered volumes.
fig, ax = kplt.subplots(nrows=1, ncols=2, sharex=True, sharey=True, figsize=(7, 6))
kplt.imview(
tangle[32],
title="Ground truth",
cmap=kplt.cm.viridis,
show_cbar=None,
ax=ax[0],
)
kplt.imview(
tangle_recon[32],
title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
cmap=kplt.cm.viridis,
ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.suptitle("Central slice on $z$ axis (axis 0)")
fig.tight_layout()
fig.show()
fig, ax = kplt.subplots(
nrows=1,
ncols=2,
sharex=True,
sharey=True,
gridspec_kw={"width_ratios": [1, 1.08]},
figsize=(13, 4),
)
kplt.imview(
tangle[:, 128],
title="Ground truth",
cmap=kplt.cm.viridis,
ax=ax[0],
)
kplt.imview(
tangle_recon[:, 128],
title="TV Reconstruction\nSNR: %.2f (dB), MAE: %.3f"
% (metric.snr(tangle, tangle_recon), metric.mae(tangle, tangle_recon)),
cmap=kplt.cm.viridis,
ax=ax[1],
)
divider = make_axes_locatable(ax[1])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[1].get_images()[0], cax=cax, label="arbitrary units")
fig.suptitle("Central slice on $y$ axis (axis 1)")
fig.tight_layout()
fig.show()
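The slice indices 32 and 128 used above are the midpoints of axes 0 and 1, assuming the phantom array has shape (Nz, Ny, Nx) = (64, 256, 128), which is consistent with `tangle[32]` being labeled the central \(z\) slice. A small helper (hypothetical, not part of komplot or scico) that derives the central slices from the array shape avoids hard-coding these indices:

```python
import numpy as np


def central_slices(vol):
    """Return the central slices of a 3D array normal to axes 0, 1, and 2."""
    n0, n1, n2 = vol.shape
    return vol[n0 // 2], vol[:, n1 // 2], vol[:, :, n2 // 2]


vol = np.zeros((64, 256, 128))  # assumed (Nz, Ny, Nx) layout of the phantom
s0, s1, s2 = central_slices(vol)
print(s0.shape, s1.shape, s2.shape)
```

The same call works unchanged on the reconstruction, so both panels of each figure are guaranteed to show the same slice.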