TV-Regularized Low-Dose CT Reconstruction
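This example demonstrates the solution of a low-dose CT reconstruction problem with isotropic total variation (TV) regularization

$$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$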
where \(A\) is the X-ray transform (the CT forward projection), \(\mathbf{y}\) is the sinogram, the norm weighting \(W\) is chosen so that the weighted norm is an approximation to the Poisson negative log likelihood [51], \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the reconstructed image.
[1]:
import numpy as np
from xdesign import Soil, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray.astra import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
# Configure scico's plotting for notebook display (see scico.plot documentation
# for exactly which matplotlib settings this applies).
plot.config_notebook_plotting()
Create a ground truth image.
[2]:
N = 512  # padded image size; also passed to XRayTransform2D as the detector count
phantom_size = 384  # size of the unpadded soil phantom
# Seed NumPy's global RNG: the Poisson noise added later is drawn from it
# (and, presumably, the random Soil phantom too -- confirm against xdesign).
np.random.seed(0)
x_gt = discrete_phantom(Soil(porosity=0.80), size=phantom_size)
# Zero-pad the phantom up to N x N. Deriving the pad widths from N keeps the
# image size, phantom size and padding consistent if either size is changed;
# an odd difference is split as evenly as possible.
pad_before = (N - phantom_size) // 2
pad_after = N - phantom_size - pad_before
x_gt = np.ascontiguousarray(np.pad(x_gt, (pad_before, pad_after)))
x_gt = np.clip(x_gt, 0, np.inf)  # clip to non-negative values
x_gt = snp.array(x_gt)  # convert to jax type
Configure CT projection operator and generate synthetic measurements.
[3]:
# Acquisition geometry: a full rotation sampled at evenly spaced angles.
n_projection = 360  # number of projections
angles = np.linspace(0, 2 * np.pi, num=n_projection, endpoint=False)

# CT forward projector acting on images shaped like x_gt, and the
# noise-free sinogram of the ground truth.
A = XRayTransform2D(x_gt.shape, N, 1.0, angles)
y_c = A @ x_gt

# Physical parameters of the noise model.
Io = 1e3  # source flux: expected count for an unattenuated ray
𝛼 = 1e-2  # attenuation coefficient
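Add Poisson noise to projections according to

$$\mathrm{counts} \sim \mathrm{Poi}\left(I_0 \exp\left\{-\alpha A \mathbf{x}\right\}\right) \;, \qquad \mathbf{y} = -\frac{1}{\alpha} \log\left(\mathrm{counts} / I_0\right) \;.$$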
We use the NumPy random functionality so we can generate using 64-bit numbers.
[4]:
# Draw photon counts from the Beer-Lambert model Io * exp(-alpha * A x).
# The noise-free sinogram y_c = A @ x_gt is reused rather than running the
# forward projection a second time.
counts = np.random.poisson(Io * snp.exp(-𝛼 * y_c))
# A zero count would make the log below -inf; replace any 0 counts with 1.
# np.maximum keeps the integer dtype (clipping against np.inf promoted to float).
counts = np.maximum(counts, 1)
# Invert the attenuation model to obtain the noisy sinogram.
y = -1 / 𝛼 * np.log(counts / Io)
y = snp.array(y)  # convert back to float32 as a jax array
Set up post-processing. For this example, we clip all reconstructions to the range of the ground truth.
[5]:
def postprocess(x):
    """Restrict a reconstruction to the value range of the ground truth.

    Values below 0 (the ground truth was clipped to be non-negative) or above
    the maximum of the global ``x_gt`` are clipped to those bounds.
    """
    upper = snp.max(x_gt)
    return snp.clip(x, 0, upper)
Compute an FBP reconstruction as an initial guess.
[6]:
# Filtered back projection (FBP) of the noisy sinogram, passed through the
# same post-processing as the iterative results; used as the ADMM start point.
x0 = A.fbp(y)
x0 = postprocess(x0)
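Set up and solve the unweighted reconstruction problem

$$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;.$$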
[7]:
# Solver settings shared with the weighted problem. The penalty parameter
# and regularization weight were selected via a parameter sweep (not shown).
ρ = 2.5e3  # ADMM penalty parameter
lambda_unweighted = 3e2  # TV regularization weight for this problem
maxiter = 100  # number of ADMM iterations
cg_tol = 1e-5  # relative tolerance of the CG solver for the x-update
cg_maxiter = 10  # CG iteration limit within each ADMM iteration

# Squared l2 data fidelity (no weighting) plus an l2,1 norm of the finite
# differences of x, i.e. isotropic TV.
f = loss.SquaredL2Loss(y=y, A=A)
g_unweighted = lambda_unweighted * functional.L21Norm()
C_unweighted = linop.FiniteDifference(x_gt.shape, append=0)
solver_unweighted = LinearSubproblemSolver(
    cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}
)

admm_unweighted = ADMM(
    f=f,
    g_list=[g_unweighted],
    C_list=[C_unweighted],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=solver_unweighted,
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
admm_unweighted.solve()
x_unweighted = postprocess(admm_unweighted.x)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res
-----------------------------------------------------------------
0 2.73e+00 1.642e+07 2.780e+03 2.876e+05 10 3.260e-04
10 1.58e+01 5.226e+06 1.041e+02 6.145e+03 10 2.410e-05
20 2.68e+01 5.273e+06 3.205e+01 1.421e+03 7 8.104e-06
30 3.54e+01 5.285e+06 1.971e+01 6.770e+02 5 8.267e-06
40 4.20e+01 5.291e+06 1.441e+01 4.295e+02 3 9.066e-06
50 4.71e+01 5.294e+06 1.132e+01 2.884e+02 3 9.611e-06
60 5.24e+01 5.296e+06 9.409e+00 1.575e+02 3 8.937e-06
70 5.77e+01 5.298e+06 7.976e+00 1.216e+02 3 7.751e-06
80 6.25e+01 5.299e+06 6.902e+00 1.081e+02 3 6.539e-06
90 6.66e+01 5.300e+06 6.073e+00 6.031e+01 1 5.481e-06
99 7.01e+01 5.301e+06 5.595e+00 1.021e+02 3 5.866e-06
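Set up and solve the weighted reconstruction problem

$$\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,$$

where

$$W = \mathrm{diag}\left\{ \mathrm{counts} / I_0 \right\} \;.$$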
The data fidelity term in this formulation follows [51] (9) except for the scaling by \(I_0\), which we use to maintain balance between the data and regularization terms if \(I_0\) changes.
[8]:
lambda_weighted = 5e1  # TV regularization weight for the weighted problem

# Per-measurement weights counts / Io, applied as a diagonal operator so the
# weighted norm approximates the Poisson negative log likelihood.
weights = snp.array(counts / Io)
weight_op = linop.Diagonal(weights)
f = loss.SquaredL2Loss(y=y, A=A, W=weight_op)

g_weighted = lambda_weighted * functional.L21Norm()
C_weighted = linop.FiniteDifference(x_gt.shape, append=0)
solver_weighted = LinearSubproblemSolver(
    cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}
)

admm_weighted = ADMM(
    f=f,
    g_list=[g_weighted],
    C_list=[C_weighted],
    rho_list=[ρ],
    maxiter=maxiter,
    x0=x0,
    subproblem_solver=solver_weighted,
    itstat_options={"display": True, "period": 10},
)

print()  # blank line separating this solver log from the previous one
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)
Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res
-----------------------------------------------------------------
0 1.42e+00 3.997e+06 4.925e+02 5.000e+04 10 1.128e-03
10 1.43e+01 2.253e+06 6.515e+01 3.768e+04 10 1.458e-04
20 2.69e+01 1.425e+06 5.424e+01 2.406e+04 10 1.564e-04
30 3.94e+01 1.113e+06 3.917e+01 1.406e+04 10 9.754e-05
40 5.22e+01 1.031e+06 2.517e+01 5.847e+03 10 5.745e-05
50 6.56e+01 1.022e+06 8.966e+00 1.788e+03 10 1.618e-05
60 7.86e+01 1.021e+06 4.037e+00 7.585e+02 9 8.861e-06
70 8.87e+01 1.021e+06 2.294e+00 4.765e+02 6 9.732e-06
80 9.69e+01 1.021e+06 1.632e+00 3.534e+02 5 9.094e-06
90 1.04e+02 1.021e+06 1.231e+00 2.752e+02 5 7.198e-06
99 1.10e+02 1.021e+06 1.023e+00 2.246e+02 3 9.953e-06
Show recovered images.
[9]:
def plot_recon(x, title, ax):
    """Plot an image with title indicating error metrics.

    SNR (dB) and MAE are computed against the ground truth ``x_gt``. The
    figure is taken from ``ax`` itself rather than from a global ``fig``, so
    the function does not depend on a variable defined after it.
    """
    snr = metric.snr(x_gt, x)
    mae = metric.mae(x_gt, x)
    plot.imview(
        x,
        title=f"{title}\nSNR: {snr:.2f} (dB), MAE: {mae:.3f}",
        fig=ax.figure,
        ax=ax,
    )


fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0])
plot_recon(x0, "FBP Reconstruction", ax=ax[0, 1])
plot_recon(x_unweighted, "Unweighted TV Reconstruction", ax=ax[1, 0])
plot_recon(x_weighted, "Weighted TV Reconstruction", ax=ax[1, 1])

# Zoom in on the unpadded phantom region (64 pixels of zero padding per side).
# The y-limits are passed in descending order so the images keep the
# row-0-at-top orientation of imshow's default origin="upper" (which imview is
# assumed to use); ascending limits would flip the images vertically.
crop_lo, crop_hi = 64, 448
for ax_ in ax.ravel():
    ax_.set_xlim(crop_lo, crop_hi)
    ax_.set_ylim(crop_hi, crop_lo)

fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
# Shared colorbar, taken from the ground-truth image in the first panel.
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="arbitrary units"
)
fig.show()