TV-Regularized Low-Dose CT Reconstruction
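This example demonstrates the solution of a low-dose CT reconstruction problem with isotropic total variation (TV) regularization
\[ \mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \left\| \mathbf{y} - A \mathbf{x} \right\|_W^2 + \lambda \left\| C \mathbf{x} \right\|_{2,1} \;, \]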
where \(A\) is the X-ray transform (the CT forward projection), \(\mathbf{y}\) is the sinogram, the norm weighting \(W\) is chosen so that the weighted norm is an approximation to the Poisson negative log likelihood [51], \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the reconstructed image.
[1]:
import warnings
import numpy as np
import komplot as kplt
from xdesign import Soil, discrete_phantom
import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.linop.xray.astra import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
# Set up komplot for plotting inside the notebook (presumably configures the
# matplotlib backend for inline display -- see komplot docs).
kplt.config_notebook_plotting()
Create a ground truth image.
[2]:
N = 512  # reconstructed image size (N x N) and detector pixel count
phantom_size = 384  # size of the generated soil phantom before zero-padding
if N < phantom_size:
    raise ValueError(f"N ({N}) must be at least phantom_size ({phantom_size})")
# Zero-pad the phantom up to N x N. The padding is derived from N so the image
# shape always matches N; an odd remainder goes on the trailing side.
pad_before = (N - phantom_size) // 2
pad_after = N - phantom_size - pad_before

# Seed the global NumPy RNG; the later np.random.poisson call draws from it too.
np.random.seed(0)
with warnings.catch_warnings():
    # Suppress any RuntimeWarning emitted while generating the phantom.
    warnings.simplefilter("ignore", category=RuntimeWarning)
    x_gt = discrete_phantom(Soil(porosity=0.60), size=phantom_size)
x_gt = np.ascontiguousarray(np.pad(x_gt, (pad_before, pad_after)))
x_gt = np.clip(x_gt, 0, np.inf)  # clip to non-negative values
x_gt = snp.array(x_gt)  # convert to jax type
Configure CT projection operator and generate synthetic measurements.
[3]:
n_projection = 360  # number of projections
Io = 1e3  # source flux (I_0 in the noise model)
𝛼 = 1e-2  # attenuation coefficient
# n_projection angles evenly spaced over [0, 2*pi); endpoint=False so that
# 0 and 2*pi (the same view) are not both included.
angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False)  # evenly spaced projection angles
# CT projection operator. NOTE(review): positional args are presumably
# (input shape, detector pixel count, detector spacing, angles) -- confirm
# against the scico XRayTransform2D documentation.
A = XRayTransform2D(x_gt.shape, N, 1.0, angles)  # CT projection operator
y_c = A @ x_gt  # noise-free sinogram of the ground truth
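Add Poisson noise to the projections according to
\[ \mathbf{c} \sim \mathrm{Poisson}\left( I_0 \exp\left( -\alpha A \mathbf{x} \right) \right) \;, \qquad \mathbf{y} = -\frac{1}{\alpha} \log\left( \frac{\mathbf{c}}{I_0} \right) \;, \]
where \(\mathbf{c}\) is the vector of photon counts, \(I_0\) is the source flux, and \(\alpha\) is the attenuation coefficient; any zero counts are replaced by 1 before taking the logarithm.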
We use the NumPy random functionality so we can generate using 64-bit numbers.
[4]:
# Draw photon counts with mean Io * exp(-alpha * A x). The noise-free
# sinogram y_c = A @ x_gt is reused instead of re-applying the projector.
counts = np.random.poisson(Io * snp.exp(-𝛼 * y_c))
# Replace any 0 counts with 1 so the logarithm below stays finite.
counts = np.maximum(counts, 1)
# Invert the attenuation model to obtain the noisy sinogram.
y = -1 / 𝛼 * np.log(counts / Io)
y = snp.array(y)  # convert back to float32 as a jax array
Set up post-processing. For this example, we clip all reconstructions to the range of the ground truth.
[5]:
def postprocess(x):
    """Clip ``x`` elementwise to the range ``[0, max(x_gt)]``.

    The upper bound is taken from the module-level ground truth ``x_gt``,
    so every reconstruction is restricted to the ground-truth value range.
    """
    upper = snp.max(x_gt)
    return snp.clip(x, 0, upper)
Compute an FBP reconstruction as an initial guess.
[6]:
# Filtered back-projection of the noisy sinogram, clipped by postprocess; used
# as the starting point x0 for both ADMM solves.
x0 = postprocess(A.fbp(y))
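Set up and solve the unweighted reconstruction problem
\[ \mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \left\| \mathbf{y} - A \mathbf{x} \right\|_2^2 + \lambda \left\| C \mathbf{x} \right\|_{2,1} \;. \]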
[7]:
# The ADMM penalty rho and the regularization weight lambda were chosen via
# a parameter sweep that is not shown here.
ρ = 2.5e3  # ADMM penalty parameter
lambda_unweighted = 3e2  # regularization strength
maxiter = 100  # number of ADMM iterations
cg_tol = 1e-5  # CG relative tolerance
cg_maxiter = 10  # maximum CG iterations per ADMM iteration

# Unweighted squared-l2 data fidelity plus lambda * ||C x||_{2,1}, with C a
# finite-difference operator. The x-subproblem is solved approximately by CG.
f = loss.SquaredL2Loss(y=y, A=A)
admm_unweighted = ADMM(
    f=f,
    g_list=[lambda_unweighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(
        cg_kwargs=dict(tol=cg_tol, maxiter=cg_maxiter),
    ),
    itstat_options=dict(display=True, period=10),
)

print(f"Solving on {device_info()}\n")
admm_unweighted.solve()
x_unweighted = postprocess(admm_unweighted.x)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)
Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res
-----------------------------------------------------------------
0 2.98e+00 2.443e+07 2.848e+03 3.000e+05 10 2.274e-04
10 9.69e+00 7.813e+06 2.190e+02 2.417e+04 10 4.892e-05
20 1.54e+01 7.833e+06 3.956e+01 2.157e+03 6 9.729e-06
30 1.97e+01 7.847e+06 2.192e+01 9.447e+02 5 8.738e-06
40 2.36e+01 7.854e+06 1.531e+01 5.511e+02 2 9.673e-06
50 2.68e+01 7.857e+06 1.167e+01 3.806e+02 3 7.047e-06
60 3.02e+01 7.860e+06 9.566e+00 1.990e+02 3 7.704e-06
70 3.36e+01 7.862e+06 7.893e+00 1.271e+02 2 9.199e-06
80 3.68e+01 7.863e+06 6.747e+00 1.383e+02 3 5.676e-06
90 3.94e+01 7.864e+06 5.912e+00 2.206e+01 0 9.402e-06
99 4.19e+01 7.864e+06 5.361e+00 1.660e+01 0 9.550e-06
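Set up and solve the weighted reconstruction problem
\[ \mathrm{argmin}_{\mathbf{x}} \; \frac{1}{2} \left\| \mathbf{y} - A \mathbf{x} \right\|_W^2 + \lambda \left\| C \mathbf{x} \right\|_{2,1} \;, \]
where
\[ W = \mathrm{diag}\left( \frac{\mathbf{c}}{I_0} \right) \;, \]
and \(\mathbf{c}\) is the vector of measured photon counts.
The data fidelity term in this formulation follows equation (9) of [51], except for the scaling by \(I_0\), which we use to maintain balance between the data and regularization terms if \(I_0\) changes.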
[8]:
lambda_weighted = 5e1  # regularization strength for the weighted problem

# Diagonal weighting W = diag(counts / Io) for the squared-l2 data fidelity;
# the regularizer, penalty, and CG settings match the unweighted solve.
weights = snp.array(counts / Io)
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))
admm_weighted = ADMM(
    f=f,
    g_list=[lambda_weighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(
        cg_kwargs=dict(tol=cg_tol, maxiter=cg_maxiter),
    ),
    itstat_options=dict(display=True, period=10),
)

print()  # blank line separating this solver log from the previous one
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)
Iter Time Objective Prml Rsdl Dual Rsdl CG It CG Res
-----------------------------------------------------------------
0 1.64e+00 5.153e+06 4.966e+02 5.067e+04 10 9.466e-04
10 1.20e+01 3.204e+06 5.621e+01 4.171e+04 10 1.360e-04
20 2.20e+01 2.033e+06 5.652e+01 3.059e+04 10 1.623e-04
30 3.33e+01 1.434e+06 4.340e+01 2.151e+04 10 1.313e-04
40 4.62e+01 1.160e+06 3.656e+01 1.370e+04 10 1.152e-04
50 5.88e+01 1.068e+06 2.327e+01 7.357e+03 10 6.871e-05
60 7.15e+01 1.046e+06 1.174e+01 3.458e+03 10 3.205e-05
70 8.44e+01 1.041e+06 6.115e+00 1.424e+03 10 1.415e-05
80 9.71e+01 1.041e+06 2.780e+00 6.394e+02 9 8.665e-06
90 1.07e+02 1.041e+06 1.761e+00 4.559e+02 7 8.477e-06
99 1.15e+02 1.041e+06 1.375e+00 3.614e+02 6 8.812e-06
Show recovered images.
[9]:
def plot_recon(x, title, ax):
    """Display image ``x`` on axes ``ax`` with error metrics in the title.

    The title is extended with the SNR (in dB) and the mean absolute error
    of ``x`` measured against the module-level ground truth ``x_gt``.
    """
    snr_db = metric.snr(x_gt, x)
    mae = metric.mae(x_gt, x)
    full_title = f"{title}\nSNR: {snr_db:.2f} (dB), MAE: {mae:.3f}"
    kplt.imview(x, cmap="Blues", title=full_title, ax=ax)
# 2 x 2 grid: ground truth, FBP, unweighted TV, and weighted TV reconstructions.
fig, ax = kplt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0, 0])
plot_recon(x0, "FBP Reconstruction", ax=ax[0, 1])
plot_recon(x_unweighted, "Unweighted TV Reconstruction", ax=ax[1, 0])
plot_recon(x_weighted, "Weighted TV Reconstruction", ax=ax[1, 1])

# Crop away the 64-pixel zero padding added around the phantom.
# set_ylim takes (bottom, top). Images drawn with matplotlib's default
# origin="upper" (assumed for kplt.imview) have row 0 at the top, so the
# larger row index must be the bottom limit; set_ylim(64, 448) would flip
# every panel vertically.
crop_lo, crop_hi = 64, 448
for ax_ in ax.ravel():
    ax_.set_xlim(crop_lo, crop_hi)
    ax_.set_ylim(crop_hi, crop_lo)

fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
# Shared colorbar, taken from the ground-truth image's color mapping.
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="arbitrary units"
)
fig.show()