TV-Regularized Sparse-View CT Reconstruction (Multiple Projectors, Common Sinogram)
This example demonstrates the solution of a sparse-view CT reconstruction problem with isotropic total variation (TV) regularization,
\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]
where \(A\) is the X-ray transform (the CT forward projection operator), \(\mathbf{y}\) is the sinogram, \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the desired image. The solution is computed and compared for all three 2D CT projectors available in scico, using a sinogram computed with the svmbir projector.
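To make the objective above concrete, the following is a minimal NumPy sketch of how the two terms could be evaluated on a toy image. It is not the scico implementation: the helper names `isotropic_tv` and `objective`, the identity stand-in for \(A\), and the boundary handling (the last difference along each axis is set to zero so that both difference arrays keep the image shape) are assumptions made for illustration.

```python
import numpy as np


def isotropic_tv(x):
    """Sum over pixels of the l2 norm of the (vertical, horizontal) differences."""
    # Forward differences; appending the last row/column makes the final
    # difference along each axis zero, so both arrays have the shape of x.
    dy = np.diff(x, axis=0, append=x[-1:, :])
    dx = np.diff(x, axis=1, append=x[:, -1:])
    # l2 norm across the two difference directions, then l1 norm over pixels.
    return np.sum(np.sqrt(dy**2 + dx**2))


def objective(x, y, forward, lam):
    """(1/2) ||y - A x||_2^2 + lam * ||C x||_{2,1} for a callable forward model."""
    residual = y - forward(x)
    return 0.5 * np.sum(residual**2) + lam * isotropic_tv(x)


# Toy 4x4 image containing a 2x2 square of ones.
x = np.zeros((4, 4))
x[1:3, 1:3] = 1.0


def forward(v):
    # Hypothetical stand-in for the X-ray transform A (identity operator).
    return v


y = forward(x)  # noise-free "measurement", so the data term is zero
tv_value = isotropic_tv(x)
obj_value = objective(x, y, forward, lam=2.0)
print(f"TV = {tv_value:.4f}, objective = {obj_value:.4f}")
```

Because the data term vanishes in this toy case, the objective reduces to \(\lambda\) times the TV value, which makes the role of the regularization weight easy to see.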
[1]:
import numpy as np
from xdesign import Foam, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray import Parallel2dProjector, XRayTransform, astra, svmbir
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()
Create a ground truth image.
[2]:
N = 512 # phantom size
np.random.seed(1234)
x_gt = snp.array(discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N))
Define CT geometry and construct array of (approximately) equivalent projectors.
[3]:
n_projection = 45 # number of projections
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
projectors = {
    "astra": astra.XRayTransform(x_gt.shape, 1, N, angles - np.pi / 2.0),  # astra
    "svmbir": svmbir.XRayTransform(x_gt.shape, 2 * np.pi - angles, N),  # svmbir
    "scico": XRayTransform(Parallel2dProjector((N, N), angles, det_count=N)),  # scico
}
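As a rough illustration of what "sparse-view" means for this geometry, the short calculation below compares the number of sinogram samples with the number of image pixels. It assumes one measurement per projection angle and detector channel, which matches the `(n_projection, N)` shape used for the sinogram in this example; the variable names `n_measurements`, `n_unknowns`, and `ratio` are introduced here only for illustration.

```python
# Values taken from the example: 512 x 512 image, 45 views, 512 detector channels.
N = 512
n_projection = 45

n_measurements = n_projection * N  # one sample per (angle, detector channel)
n_unknowns = N * N  # one unknown per image pixel
ratio = n_measurements / n_unknowns

print(f"{n_measurements} measurements for {n_unknowns} unknowns ({ratio:.1%})")
```

With far fewer measurements than unknowns, the data term alone does not determine the image, which is why the TV regularizer is needed.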
Compute common sinogram using svmbir projector.
[4]:
A = projectors["svmbir"]
noise = np.random.normal(size=(n_projection, N)).astype(np.float32)
y = A @ x_gt + 2.0 * noise
Solve the same problem using the different projectors.
[5]:
print(f"Solving on {device_info()}")
x_rec, hist = {}, {}
for p in ("astra", "svmbir", "scico"):
    print(f"\nSolving with {p} projector")
    # Set up ADMM solver parameters.
    λ = 2e0  # L2,1 norm regularization parameter
    ρ = 5e0  # ADMM penalty parameter
    maxiter = 25  # number of ADMM iterations
    cg_tol = 1e-4  # CG relative tolerance
    cg_maxiter = 25  # maximum CG iterations per ADMM iteration
    # The append=0 option makes the results of horizontal and vertical
    # finite differences the same shape, which is required for the L21Norm,
    # which is used so that g(Cx) corresponds to isotropic TV.
    C = linop.FiniteDifference(input_shape=x_gt.shape, append=0)
    g = λ * functional.L21Norm()
    A = projectors[p]
    f = loss.SquaredL2Loss(y=y, A=A)
    # Initialize with the clipped back projection of the sinogram.
    x0 = snp.clip(A.T(y), 0, 1.0)
    # Set up the solver.
    solver = ADMM(
        f=f,
        g_list=[g],
        C_list=[C],
        rho_list=[ρ],
        x0=x0,
        maxiter=maxiter,
        subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
        itstat_options={"display": True, "period": 5},
    )
    # Run the solver.
    solver.solve()
    hist[p] = solver.itstat_object.history(transpose=True)
    x_rec[p] = snp.clip(solver.x, 0, 1.0)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Solving with astra projector
Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  3.11e+00  9.371e+03  1.197e+02  3.687e+00     25  5.578e-04
   5  1.14e+01  3.389e+04  4.048e+01  7.760e+01     18  9.097e-05
  10  1.54e+01  3.753e+04  3.316e+01  4.231e+01     11  9.641e-05
  15  1.84e+01  3.793e+04  3.177e+01  3.489e+01     12  9.439e-05
  20  1.99e+01  3.853e+04  1.159e+01  5.353e+00      1  8.981e-05
  24  2.15e+01  3.865e+04  1.234e+01  5.924e+00      0  9.468e-05

Solving with svmbir projector
Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  1.93e+01  5.671e+03  1.249e+02  5.708e+00     25  4.341e-04
   5  1.03e+02  3.272e+04  3.791e+01  8.424e+01     16  7.794e-05
  10  1.44e+02  3.662e+04  2.992e+01  5.590e+01     15  9.741e-05
  15  1.77e+02  3.719e+04  1.295e+01  2.983e+01     11  9.084e-05
  20  1.94e+02  3.758e+04  1.129e+01  5.027e+00      1  6.963e-05
  24  2.00e+02  3.793e+04  1.018e+01  4.801e+00      0  9.264e-05

Solving with scico projector
Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.55e-01  1.149e+04  1.209e+02  8.344e+00     25  7.804e-04
   5  6.77e-01  3.484e+04  4.156e+01  7.835e+01     15  9.443e-05
  10  9.20e-01  3.844e+04  3.479e+01  4.478e+01     12  9.975e-05
  15  1.10e+00  3.896e+04  1.398e+01  2.853e+01     11  9.003e-05
  20  1.21e+00  3.934e+04  2.256e+01  1.237e+01      0  7.747e-05
  24  1.34e+00  3.937e+04  2.291e+01  9.310e+00      0  8.819e-05
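The structure of the ADMM iterations reported above (a linear solve for the image, a shrinkage step on the differences, and a dual update) can be sketched in plain NumPy on a much smaller problem. The sketch below is a simplified stand-in and not the scico solver: it denoises a 1D signal (so \(A\) is the identity and the \(\ell_{2,1}\) norm reduces to the \(\ell_1\) norm), it solves the linear subproblem directly with `np.linalg.solve` rather than with CG, and the names `soft_threshold` and `admm_tv_denoise_1d` as well as the parameter values are hypothetical.

```python
import numpy as np


def soft_threshold(v, t):
    """Proximal operator of t * ||.||_1 (elementwise shrinkage)."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)


def admm_tv_denoise_1d(y, lam, rho, maxiter):
    """ADMM sketch for argmin_x (1/2) ||y - x||_2^2 + lam * ||D x||_1."""
    n = y.size
    D = np.diff(np.eye(n), axis=0)  # (n-1, n) forward difference matrix
    lhs = np.eye(n) + rho * D.T @ D  # system matrix of the x-update
    x = y.copy()
    z = D @ x  # auxiliary variable for the split z = D x
    u = np.zeros(n - 1)  # scaled dual variable
    for _ in range(maxiter):
        # x-update: linear solve (the example above uses CG for this step).
        x = np.linalg.solve(lhs, y + rho * D.T @ (z - u))
        Dx = D @ x
        # z-update: shrinkage, i.e. the proximal operator of (lam / rho) * l1.
        z = soft_threshold(Dx + u, lam / rho)
        # Dual update: accumulate the constraint violation D x - z.
        u = u + Dx - z
    return x


# Piecewise-constant test signal with additive Gaussian noise.
rng = np.random.default_rng(0)
clean = np.repeat([0.0, 1.0, 0.3], 40)
noisy = clean + 0.1 * rng.standard_normal(clean.size)
denoised = admm_tv_denoise_1d(noisy, lam=0.5, rho=1.0, maxiter=300)

print("error before:", np.linalg.norm(noisy - clean))
print("error after: ", np.linalg.norm(denoised - clean))
```

In the CT example the x-update involves \(A^T A + \rho C^T C\), which is too large to factor directly, so an iterative solver with a tolerance and iteration cap (the "CG It" and "CG Res" columns) takes the place of the direct solve used here.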
Display sinogram.
[6]:
fig, ax = plot.subplots(nrows=1, ncols=1, figsize=(15, 3))
plot.imview(y, title="sinogram", fig=fig, ax=ax)
fig.show()
Plot convergence statistics.
[7]:
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(12, 5))
plot.plot(
    np.vstack([hist[p].Objective for p in projectors.keys()]).T,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    lgnd=projectors.keys(),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    np.vstack([hist[p].Prml_Rsdl for p in projectors.keys()]).T,
    ptyp="semilogy",
    title="Primal Residual",
    xlbl="Iteration",
    fig=fig,
    ax=ax[1],
)
plot.plot(
    np.vstack([hist[p].Dual_Rsdl for p in projectors.keys()]).T,
    ptyp="semilogy",
    title="Dual Residual",
    xlbl="Iteration",
    fig=fig,
    ax=ax[2],
)
fig.show()
Show the recovered images.
[8]:
fig, ax = plot.subplots(nrows=1, ncols=4, figsize=(15, 5))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0])
for n, p in enumerate(projectors.keys()):
    plot.imview(
        x_rec[p],
        title="%s SNR: %.2f (dB)" % (p, metric.snr(x_gt, x_rec[p])),
        fig=fig,
        ax=ax[n + 1],
    )
fig.show()
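The SNR values in the image titles compare each reconstruction with the ground truth. A minimal NumPy sketch of one common definition of SNR in dB is given below; whether `metric.snr` uses exactly this normalization is an assumption here, and the function name `snr_db` is hypothetical.

```python
import numpy as np


def snr_db(reference, comparison):
    """Signal-to-noise ratio in dB: reference energy over error energy."""
    signal_energy = np.sum(np.abs(reference) ** 2)
    error_energy = np.sum(np.abs(reference - comparison) ** 2)
    return 10.0 * np.log10(signal_energy / error_energy)


# Toy check: a constant offset of 0.1 on an all-ones image gives an
# energy ratio of about 100, i.e. approximately 20 dB.
reference = np.ones((8, 8))
comparison = reference + 0.1
value = snr_db(reference, comparison)
print(f"SNR: {value:.2f} dB")
```

Under this definition, a higher value means the reconstruction is closer to the reference, so the titles allow a direct comparison of the three projectors.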