TV-Regularized Low-Dose CT Reconstruction
This example demonstrates the solution of a low-dose CT reconstruction problem with isotropic total variation (TV) regularization
\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]
where \(A\) is the X-ray transform (the CT forward projection), \(\mathbf{y}\) is the sinogram, the norm weighting \(W\) is chosen so that the weighted norm is an approximation to the Poisson negative log likelihood [45], \(C\) is a 2D finite difference operator, \(\lambda\) is the regularization parameter, and \(\mathbf{x}\) is the desired image.
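To make the objective concrete, here is a small NumPy sketch, independent of SCICO, that evaluates it for an image \(\mathbf{x}\). It uses a dense matrix in place of \(A\) and a vector of diagonal weights for \(W\). The finite differences append a zero at the far edge of each axis. This is our reading of the append=0 option passed to linop.FiniteDifference later in the example and should be treated as an assumption. The helper names are hypothetical.

```python
import numpy as np


def finite_difference(x):
    """Stack forward differences along axes 0 and 1, giving shape (2, *x.shape).

    A zero is appended at the far edge of each axis (assumed to mirror
    append=0 in SCICO), so each difference image has the same shape as x.
    """
    return np.stack([np.diff(x, axis=0, append=0), np.diff(x, axis=1, append=0)])


def l21_norm(u):
    """Isotropic TV: l2 norm over the gradient axis (axis 0), summed over pixels."""
    return np.sum(np.sqrt(np.sum(u**2, axis=0)))


def objective(x, A, y, w, lam):
    """(1/2) ||y - A x||_W^2 + lam ||C x||_{2,1} with W = diag(w)."""
    r = y - A @ x.ravel()
    return 0.5 * np.sum(w * r**2) + lam * l21_norm(finite_difference(x))


# Tiny example: a 2x2 constant image observed through an identity "projector".
x = np.ones((2, 2))
A = np.eye(4)
y = A @ x.ravel()
print(objective(x, A, y, np.ones(4), lam=1.0))  # data term is 0; TV is 2 + sqrt(2)
```

The constant image still has nonzero TV because differencing against the appended zero penalizes the image boundary.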
[1]:
import numpy as np
from xdesign import Soil, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric, plot
from scico.linop.xray.astra import XRayTransform
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()
Create a ground truth image.
[2]:
N = 512  # phantom size after zero padding (384 + 2 * 64)
np.random.seed(0)
x_gt = discrete_phantom(Soil(porosity=0.80), size=384)
x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))  # zero pad to N x N
x_gt = np.clip(x_gt, 0, np.inf)  # clip to non-negative values
x_gt = snp.array(x_gt)  # convert to a JAX array
Configure the CT projection operator and generate synthetic measurements.
[3]:
n_projection = 360 # number of projections
Io = 1e3 # source flux
𝛼 = 1e-2 # attenuation coefficient
angles = np.linspace(0, 2 * np.pi, n_projection)  # evenly spaced angles from 0 to 2*pi inclusive
A = XRayTransform(x_gt.shape, 1.0, N, angles)  # CT projection operator
y_c = A @ x_gt  # noise-free sinogram
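By default np.linspace includes its end point, so the angles above run from 0 to 2π inclusive. Because 0 and 2π describe the same projection direction, the first and last projections coincide. The short NumPy check below illustrates this. Passing endpoint=False would instead give 360 distinct angles spaced by exactly 1°. The example itself keeps the default.

```python
import numpy as np

n_projection = 360

# Default endpoint=True: the last angle equals 2*pi (same direction as 0).
angles_incl = np.linspace(0, 2 * np.pi, n_projection)
# endpoint=False: n_projection distinct angles spaced by 2*pi / n_projection.
angles_excl = np.linspace(0, 2 * np.pi, n_projection, endpoint=False)

print(np.degrees(angles_incl[1] - angles_incl[0]))  # spacing of 360/359 degrees
print(np.degrees(angles_excl[1] - angles_excl[0]))  # spacing of 1 degree
```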
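Add Poisson noise to the projections according to
\[\mathrm{counts} \sim \mathrm{Poi}\left(I_0 \exp\left\{-\alpha \mathbf{y}_c\right\}\right) \;, \qquad \mathbf{y} = -\frac{1}{\alpha} \log\left(\mathrm{counts} / I_0\right) \;,\]
where \(\mathbf{y}_c = A \mathbf{x}_{\mathrm{gt}}\) is the noise-free sinogram, \(I_0\) is the source flux, and \(\alpha\) is the attenuation coefficient. We use NumPy's random number generator (rather than JAX's) so that the noise can be generated using 64-bit numbers.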
[4]:
counts = np.random.poisson(Io * snp.exp(-𝛼 * y_c))  # Poisson-distributed photon counts
counts = np.clip(counts, a_min=1, a_max=np.inf)  # replace zero counts with 1 to avoid log(0)
y = -1 / 𝛼 * np.log(counts / Io)
y = snp.array(y)  # convert back to float32 as a JAX array
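The noise model can be tested on its own with a minimal NumPy sketch. It assumes a hypothetical helper, simulate_low_dose, and uses NumPy's Generator API rather than the legacy np.random.poisson call above. The sketch also shows a side effect of replacing zero counts with 1: the measured line integral can never exceed \(\log(I_0)/\alpha\). For \(I_0 = 10^3\) and \(\alpha = 10^{-2}\) that ceiling is about 690.8, so very strongly attenuated rays are underestimated.

```python
import numpy as np


def simulate_low_dose(y_clean, I0, alpha, seed=0):
    """Hypothetical helper: Poisson counts and the log-transformed sinogram."""
    rng = np.random.default_rng(seed)
    expected = I0 * np.exp(-alpha * y_clean)  # mean photon count per detector bin
    counts = rng.poisson(expected).astype(np.float64)
    counts = np.clip(counts, 1, None)  # avoid log(0) in photon-starved bins
    y = -np.log(counts / I0) / alpha  # back to line-integral units
    return y, counts


I0, alpha = 1e3, 1e-2
# Moderate attenuation: the noisy values scatter around the clean value 50.
y_mod, _ = simulate_low_dose(np.full(10_000, 50.0), I0, alpha)
# Heavy attenuation: most counts are 0 before clipping, so y saturates.
y_high, counts_high = simulate_low_dose(np.full(10_000, 1000.0), I0, alpha)
print(y_mod.mean(), y_high.max(), np.log(I0) / alpha)
```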
Set up post-processing. For this example, we clip all reconstructions to the range of the ground truth.
[5]:
def postprocess(x):
    """Clip x to the value range of the ground truth."""
    return snp.clip(x, 0, snp.max(x_gt))
Compute an FBP reconstruction as an initial guess.
[6]:
x0 = postprocess(A.fbp(y))
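Set up and solve the unweighted reconstruction problem
\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;.\]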
[7]:
# Note that rho and lambda were selected via a parameter sweep (not
# shown here).
ρ = 2.5e3 # ADMM penalty parameter
lambda_unweighted = 3e2 # regularization strength
maxiter = 100 # number of ADMM iterations
cg_tol = 1e-5 # CG relative tolerance
cg_maxiter = 10 # maximum CG iterations per ADMM iteration
f = loss.SquaredL2Loss(y=y, A=A)
admm_unweighted = ADMM(
    f=f,
    g_list=[lambda_unweighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
    itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
admm_unweighted.solve()
x_unweighted = postprocess(admm_unweighted.x)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.64e+00  2.404e+07  5.755e+01  1.197e+02     10  5.986e-04
  10  8.47e+00  5.200e+06  4.310e+00  5.302e+00     10  4.992e-05
  20  1.08e+01  5.267e+06  7.188e-01  6.798e-01     10  1.015e-05
  30  1.28e+01  5.280e+06  4.211e-01  3.066e-01      5  9.675e-06
  40  1.44e+01  5.287e+06  2.976e-01  1.712e-01      4  9.746e-06
  50  1.59e+01  5.290e+06  2.312e-01  9.611e-02      1  9.044e-06
  60  1.73e+01  5.293e+06  1.903e-01  7.309e-02      3  6.971e-06
  70  1.87e+01  5.294e+06  1.610e-01  5.191e-02      3  6.897e-06
  80  2.00e+01  5.295e+06  1.387e-01  3.472e-02      2  8.061e-06
  90  2.13e+01  5.296e+06  1.206e-01  4.024e-02      3  5.555e-06
  99  2.24e+01  5.297e+06  1.081e-01  2.486e-02      2  7.904e-06
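The following small, dense NumPy sketch shows what the ADMM solver does. It implements scaled-form ADMM for the same kind of problem with the splitting \(\mathbf{z} = C\mathbf{x}\), alternating three updates:

- \(\mathbf{x} \leftarrow (A^T W A + \rho C^T C)^{-1}(A^T W \mathbf{y} + \rho C^T(\mathbf{z} - \mathbf{u}))\)
- \(\mathbf{z} \leftarrow \mathrm{prox}_{(\lambda/\rho)\|\cdot\|_{2,1}}(C\mathbf{x} + \mathbf{u})\), the proximal operator of the \(\ell_{2,1}\) norm (block soft thresholding)
- \(\mathbf{u} \leftarrow \mathbf{u} + C\mathbf{x} - \mathbf{z}\)

This is an illustration under simplifying assumptions, not SCICO's implementation:

- \(A\) and \(C\) are explicit matrices.
- The \(\mathbf{x}\)-update is solved exactly with np.linalg.solve, whereas LinearSubproblemSolver runs a few conjugate gradient iterations.
- The helper names are hypothetical.

The weight vector w covers both problems: w = 1 gives the unweighted case.

```python
import numpy as np


def tv_matrix(n):
    """Explicit C for an n x n image (row-major), zero appended at far edges."""
    D = -np.eye(n) + np.eye(n, k=1)  # 1D forward difference with append=0
    I = np.eye(n)
    return np.vstack([np.kron(D, I), np.kron(I, D)])  # axis-0 and axis-1 diffs


def prox_l21(v, t):
    """Block soft thresholding of v (shape (2, npix)); groups are columns."""
    norm = np.sqrt(np.sum(v**2, axis=0, keepdims=True))
    return np.maximum(1 - t / np.maximum(norm, 1e-12), 0) * v


def admm_tv(A, y, w, lam, rho, n, maxiter=200):
    """Scaled ADMM for (1/2)||y - A x||_W^2 + lam ||C x||_{2,1}, W = diag(w)."""
    C = tv_matrix(n)
    AtW = A.T * w  # A^T W
    lhs = AtW @ A + rho * C.T @ C  # x-update system matrix (fixed)
    x = np.zeros(n * n)
    z = np.zeros(C.shape[0])
    u = np.zeros(C.shape[0])  # scaled dual variable
    for _ in range(maxiter):
        x = np.linalg.solve(lhs, AtW @ y + rho * C.T @ (z - u))
        Cx = C @ x
        z = prox_l21((Cx + u).reshape(2, -1), lam / rho).ravel()
        u = u + Cx - z
    return x.reshape(n, n)


def tv(x):
    """Isotropic TV of a square image via the same explicit C."""
    g = (tv_matrix(x.shape[0]) @ x.ravel()).reshape(2, -1)
    return np.sum(np.sqrt(np.sum(g**2, axis=0)))


# Denoising example (A = I): a noisy piecewise-constant 6 x 6 image.
rng = np.random.default_rng(0)
n = 6
truth = np.zeros((n, n))
truth[:, n // 2 :] = 1.0
noisy = truth + 0.1 * rng.standard_normal((n, n))
x_hat = admm_tv(np.eye(n * n), noisy.ravel(), np.ones(n * n), lam=0.1, rho=1.0, n=n)
print(tv(noisy), tv(x_hat))
```

Forming \(A^T W A + \rho C^T C\) explicitly is only practical for toy sizes. For a 512 × 512 image with a CT projector, an iterative solver such as CG, applied through operator products, is the natural choice.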
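Set up and solve the weighted reconstruction problem
\[\mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]
where
\[W = \mathrm{diag}\left\{ \mathrm{counts} / I_0 \right\} \;.\]
The data fidelity term in this formulation follows Eq. (9) of [45], except for the scaling by \(I_0\), which we use to maintain the balance between the data and regularization terms if \(I_0\) changes.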
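One way to see why \(W = \mathrm{diag}\{\mathrm{counts}/I_0\}\) approximates the Poisson negative log likelihood is a standard second-order Taylor argument. It is offered here as a sketch, not as a restatement of [45].

Take a single detector bin with measured count \(c\) and line integral \(l\). Up to a constant, the negative log likelihood is \(I_0 e^{-\alpha l} - c \log(I_0 e^{-\alpha l})\). Its first derivative in \(l\) vanishes at the measured value \(y = -\log(c/I_0)/\alpha\), and its second derivative there is \(\alpha^2 c\). Expanding to second order gives \(\tfrac{1}{2}\alpha^2 c\,(l - y)^2 = \tfrac{1}{2}\alpha^2 I_0\,(c/I_0)(l - y)^2\). Summing over bins yields \(\tfrac{1}{2}\|\mathbf{y} - A\mathbf{x}\|_W^2\) up to the constant factor \(\alpha^2 I_0\), which can be absorbed into \(\lambda\).

The NumPy check below compares the exact and quadratic terms for one bin. The function names are hypothetical.

```python
import numpy as np


def poisson_nll(l, c, I0, alpha):
    """Negative log likelihood of count c given line integral l (constants dropped)."""
    mean = I0 * np.exp(-alpha * l)
    return mean - c * np.log(mean)


def quadratic_term(l, c, I0, alpha):
    """Second-order expansion about y = -log(c / I0) / alpha, written with weight c / I0."""
    y = -np.log(c / I0) / alpha
    return 0.5 * alpha**2 * I0 * (c / I0) * (l - y) ** 2


I0, alpha, c = 1e3, 1e-2, 600.0
y = -np.log(c / I0) / alpha
for dl in (0.5, 2.0, 10.0):
    exact = poisson_nll(y + dl, c, I0, alpha) - poisson_nll(y, c, I0, alpha)
    print(dl, exact, quadratic_term(y + dl, c, I0, alpha))
```

The agreement degrades as \(l\) moves away from \(y\), which is expected for a local quadratic approximation.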
[8]:
lambda_weighted = 5e1  # regularization strength (weighted problem)
weights = snp.array(counts / Io)  # diagonal of W
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))
admm_weighted = ADMM(
    f=f,
    g_list=[lambda_weighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    maxiter=maxiter,
    x0=x0,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
    itstat_options={"display": True, "period": 10},
)
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)
Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  5.76e-01  5.315e+06  9.950e+00  2.064e+01     10  1.404e-03
  10  3.29e+00  3.439e+06  9.724e-01  1.760e+01     10  1.081e-04
  20  5.77e+00  2.190e+06  1.083e+00  1.369e+01     10  1.158e-04
  30  8.22e+00  1.486e+06  9.219e-01  9.942e+00     10  9.922e-05
  40  1.07e+01  1.144e+06  8.045e-01  6.458e+00     10  8.325e-05
  50  1.31e+01  1.036e+06  5.633e-01  2.883e+00     10  5.363e-05
  60  1.55e+01  1.022e+06  2.147e-01  8.679e-01     10  1.792e-05
  70  1.79e+01  1.021e+06  8.207e-02  3.505e-01      9  8.761e-06
  80  1.99e+01  1.021e+06  4.981e-02  2.153e-01      7  8.721e-06
  90  2.16e+01  1.021e+06  3.444e-02  1.583e-01      5  9.735e-06
  99  2.31e+01  1.021e+06  2.628e-02  1.263e-01      5  8.050e-06
Show recovered images.
[9]:
def plot_recon(x, title, ax):
    """Plot an image with title indicating error metrics."""
    plot.imview(
        x,
        title=f"{title}\nSNR: {metric.snr(x_gt, x):.2f} (dB), MAE: {metric.mae(x_gt, x):.3f}",
        fig=fig,
        ax=ax,
    )
fig, ax = plot.subplots(nrows=2, ncols=2, figsize=(11, 10))
plot.imview(x_gt, title="Ground truth", fig=fig, ax=ax[0, 0])
plot_recon(x0, "FBP Reconstruction", ax=ax[0, 1])
plot_recon(x_unweighted, "Unweighted TV Reconstruction", ax=ax[1, 0])
plot_recon(x_weighted, "Weighted TV Reconstruction", ax=ax[1, 1])
for ax_ in ax.ravel():
    # Crop to the unpadded 384 x 384 phantom region. The y limits are given as
    # (bottom, top) = (448, 64) so that the image keeps its usual orientation.
    ax_.set_xlim(64, 448)
    ax_.set_ylim(448, 64)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="arbitrary units"
)
fig.show()
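The panel titles report SNR (in dB) and mean absolute error (MAE) relative to the ground truth. Below is a minimal NumPy sketch of these two metrics. It assumes the common definition of SNR as \(10 \log_{10}\) of the reference variance divided by the mean squared error. We have not checked that this matches scico.metric.snr in every detail, so treat it as an illustration. The function names are hypothetical.

```python
import numpy as np


def snr_db(reference, comparison):
    """SNR in dB: 10 log10(var(reference) / MSE) (assumed definition)."""
    mse = np.mean(np.abs(reference - comparison) ** 2)
    return 10 * np.log10(np.var(reference) / mse)


def mae(reference, comparison):
    """Mean absolute error."""
    return np.mean(np.abs(reference - comparison))


ref = np.array([0.0, 1.0, 0.0, 1.0])
est = ref + 0.05  # constant offset of 0.05
print(snr_db(ref, est), mae(ref, est))  # 20 dB and 0.05 (up to rounding)
```

Because this SNR normalizes by the reference variance, it is insensitive to the overall scale of the phantom. MAE, by contrast, is reported in the image's own (arbitrary) units.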