TV-Regularized Low-Dose CT Reconstruction

This example demonstrates the solution of a low-dose CT reconstruction problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]

where \(A\) is the X-ray transform (the CT forward projection), \(\mathbf{y}\) is the sinogram, the norm weighting \(W\) is chosen so that the weighted norm is an approximation to the Poisson negative log likelihood [51], \(C\) is a 2D finite difference operator, and \(\mathbf{x}\) is the reconstructed image.
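To make the two terms of this objective concrete, the following minimal NumPy sketch evaluates it for a small image. It assumes a dense matrix `A` acting on the row-major flattened image and a vector `w` holding the diagonal of \(W\). The helper names `isotropic_tv` and `objective` are hypothetical and not part of SCICO. The differences are taken with a zero appended along each axis, which is our reading of the `append=0` option passed to `linop.FiniteDifference` below.

```python
import numpy as np


def isotropic_tv(x):
    """Hypothetical helper: ||C x||_{2,1} for a 2D image x.

    C takes forward differences along each axis after appending a zero,
    so each difference image has the same shape as x.
    """
    d0 = np.diff(x, axis=0, append=0)  # differences along axis 0
    d1 = np.diff(x, axis=1, append=0)  # differences along axis 1
    # l2 norm across the two directions at each pixel, then l1 (sum) over pixels
    return np.sum(np.sqrt(d0**2 + d1**2))


def objective(x, y, A, w, lam):
    """Hypothetical helper: (1/2) ||y - A x||_W^2 + lam ||C x||_{2,1}, W = diag(w).

    A is assumed to be a dense matrix acting on the flattened image,
    unlike the matrix-free CT operator used in this example.
    """
    r = y - A @ x.ravel()
    return 0.5 * np.sum(w * r**2) + lam * isotropic_tv(x)


x = np.array([[0.0, 1.0], [0.0, 0.0]])
print(isotropic_tv(x))  # 1 + sqrt(2) for this image
```

Taking the \(\ell_2\) norm across the two difference directions at each pixel, and only then summing over pixels, is what makes the TV isotropic; an anisotropic variant would instead sum the absolute values of both differences.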

[1]:
import warnings

import numpy as np

import komplot as kplt
from xdesign import Soil, discrete_phantom

import scico.numpy as snp
from scico import functional, linop, loss, metric
from scico.linop.xray.astra import XRayTransform2D
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
kplt.config_notebook_plotting()

Create a ground truth image.

[2]:
N = 512  # phantom size
np.random.seed(0)
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=RuntimeWarning)
    x_gt = discrete_phantom(Soil(porosity=0.60), size=384)
x_gt = np.ascontiguousarray(np.pad(x_gt, (64, 64)))
x_gt = np.clip(x_gt, 0, np.inf)  # clip to non-negative values
x_gt = snp.array(x_gt)  # convert to a JAX array

Configure CT projection operator and generate synthetic measurements.

[3]:
n_projection = 360  # number of projections
Io = 1e3  # source flux
𝛼 = 1e-2  # attenuation coefficient
angles = np.linspace(0, 2 * np.pi, n_projection, endpoint=False)  # evenly spaced projection angles
A = XRayTransform2D(x_gt.shape, N, 1.0, angles)  # CT projection operator
y_c = A @ x_gt  # sinogram

Add Poisson noise to projections according to

\[\mathrm{counts} \sim \mathrm{Poi}\left(I_0 \exp (- \alpha A \mathbf{x} ) \right)\]
\[\mathbf{y} = - \frac{1}{\alpha} \log\left(\mathrm{counts} / I_0\right) \;.\]

We use NumPy's random number generator, rather than JAX's, so that the noise is generated using 64-bit numbers.
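The noise model can be checked in isolation. The sketch below simulates many realizations of a single ray with a hypothetical noise-free line integral `p`, using the same \(I_0\) and \(\alpha\) as the example but NumPy's `default_rng` generator instead of the legacy `np.random` calls used in the cell below. A first-order (delta method) argument predicts \(\mathrm{Var}[y] \approx 1 / (\alpha^2 \, \mathbb{E}[\mathrm{counts}])\), so rays with fewer counts are noisier. This is the motivation for weighting each residual by a quantity proportional to its counts in the weighted problem.

```python
import numpy as np

I0 = 1e3      # source flux, as in the example
alpha = 1e-2  # attenuation coefficient, as in the example
p = 50.0      # hypothetical noise-free line integral (A x) for one ray
n = 200_000   # number of independent noise realizations

rng = np.random.default_rng(0)  # NumPy generator, 64-bit draws
counts = rng.poisson(I0 * np.exp(-alpha * p), size=n)
counts = np.clip(counts, 1, None)  # replace zero counts with 1 to avoid log(0)
y = -np.log(counts / I0) / alpha   # log-transformed measurement

# Delta-method prediction: Var[y] ~= 1 / (alpha^2 * E[counts])
mean_counts = I0 * np.exp(-alpha * p)
var_pred = 1.0 / (alpha**2 * mean_counts)

print(f"mean of y: {y.mean():.3f} (true line integral {p})")
print(f"variance of y: {y.var():.3f} (delta-method prediction {var_pred:.3f})")
```

The sample mean is slightly above `p` because the logarithm is concave, but at this flux the bias is small compared with the noise standard deviation.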

[4]:
counts = np.random.poisson(Io * snp.exp(-𝛼 * A @ x_gt))
counts = np.clip(counts, a_min=1, a_max=np.inf)  # replace any zero counts with 1 to avoid log(0)
y = -1 / 𝛼 * np.log(counts / Io)
y = snp.array(y)  # convert back to float32 as a jax array

Set up post-processing. For this example, we clip all reconstructions to the range of the ground truth.

[5]:
def postprocess(x):
    return snp.clip(x, 0, snp.max(x_gt))

Compute an FBP reconstruction as an initial guess.

[6]:
x0 = postprocess(A.fbp(y))

Set up and solve the unweighted reconstruction problem

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} \;.\]

[7]:
# Note that rho and lambda were selected via a parameter sweep (not
# shown here).
ρ = 2.5e3  # ADMM penalty parameter
lambda_unweighted = 3e2  # regularization strength
maxiter = 100  # number of ADMM iterations
cg_tol = 1e-5  # CG relative tolerance
cg_maxiter = 10  # maximum CG iterations per ADMM iteration
f = loss.SquaredL2Loss(y=y, A=A)
admm_unweighted = ADMM(
    f=f,
    g_list=[lambda_unweighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    x0=x0,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
    itstat_options={"display": True, "period": 10},
)
print(f"Solving on {device_info()}\n")
admm_unweighted.solve()
x_unweighted = postprocess(admm_unweighted.x)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  2.98e+00  2.443e+07  2.848e+03  3.000e+05     10  2.274e-04
  10  9.69e+00  7.813e+06  2.190e+02  2.417e+04     10  4.892e-05
  20  1.54e+01  7.833e+06  3.956e+01  2.157e+03      6  9.729e-06
  30  1.97e+01  7.847e+06  2.192e+01  9.447e+02      5  8.738e-06
  40  2.36e+01  7.854e+06  1.531e+01  5.511e+02      2  9.673e-06
  50  2.68e+01  7.857e+06  1.167e+01  3.806e+02      3  7.047e-06
  60  3.02e+01  7.860e+06  9.566e+00  1.990e+02      3  7.704e-06
  70  3.36e+01  7.862e+06  7.893e+00  1.271e+02      2  9.199e-06
  80  3.68e+01  7.863e+06  6.747e+00  1.383e+02      3  5.676e-06
  90  3.94e+01  7.864e+06  5.912e+00  2.206e+01      0  9.402e-06
  99  4.19e+01  7.864e+06  5.361e+00  1.660e+01      0  9.550e-06
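The `ADMM` call above splits the problem through the auxiliary variable \(\mathbf{z} = C \mathbf{x}\). As a rough illustration of what each iteration involves, here is a dense NumPy sketch of the standard scaled-form ADMM updates for this objective, assuming SCICO's solver follows that usual form. It is a toy under stated assumptions, not SCICO's implementation: the x-update is solved exactly with `np.linalg.solve`, whereas the example uses `LinearSubproblemSolver` with at most `cg_maxiter` CG iterations, and all function names, sizes, and data are hypothetical.

```python
import numpy as np


def forward_diff_matrix(n):
    """Dense 1D forward differences with a zero appended, like np.diff(x, append=0)."""
    D = -np.eye(n)
    D[np.arange(n - 1), np.arange(1, n)] = 1.0
    return D


def tv_operator(shape):
    """Stack of axis-0 and axis-1 differences acting on a row-major flattened image."""
    n0, n1 = shape
    return np.vstack([
        np.kron(forward_diff_matrix(n0), np.eye(n1)),
        np.kron(np.eye(n0), forward_diff_matrix(n1)),
    ])


def admm_tv(A, y, w, shape, lam, rho, maxiter):
    """Toy scaled-form ADMM for (1/2)||y - A x||_W^2 + lam ||C x||_{2,1}, W = diag(w)."""
    C = tv_operator(shape)
    npix = C.shape[1]
    x = np.zeros(npix)
    z = np.zeros(C.shape[0])
    u = np.zeros(C.shape[0])
    AtW = A.T * w                    # A^T W
    H = AtW @ A + rho * (C.T @ C)    # x-update system matrix (fixed for fixed rho)
    for _ in range(maxiter):
        # x-update: (A^T W A + rho C^T C) x = A^T W y + rho C^T (z - u)
        x = np.linalg.solve(H, AtW @ y + rho * C.T @ (z - u))
        # z-update: prox of (lam/rho)||.||_{2,1}, i.e. group soft thresholding per pixel
        v = (C @ x + u).reshape(2, npix)
        mag = np.linalg.norm(v, axis=0)
        shrink = np.maximum(0.0, 1.0 - (lam / rho) / np.maximum(mag, 1e-12))
        z = (v * shrink).ravel()
        # scaled dual update
        u = u + C @ x - z
    return x.reshape(shape), np.linalg.norm(C @ x - z)


def tv_objective(x, A, y, w, lam):
    r = y - A @ x.ravel()
    g = (tv_operator(x.shape) @ x.ravel()).reshape(2, -1)
    return 0.5 * np.sum(w * r**2) + lam * np.sum(np.linalg.norm(g, axis=0))


# Toy usage: piecewise-constant 6x6 image and a random dense "projector".
rng = np.random.default_rng(0)
shape = (6, 6)
x_true = np.zeros(shape)
x_true[1:4, 2:5] = 1.0
A = rng.standard_normal((200, 36)) / np.sqrt(200)
y = A @ x_true.ravel() + 0.05 * rng.standard_normal(200)
w = np.ones(200)  # unweighted case, W = I

x_hat, prim_res = admm_tv(A, y, w, shape, lam=0.1, rho=1.0, maxiter=1000)
x_ls = np.linalg.lstsq(A, y, rcond=None)[0].reshape(shape)
print(f"primal residual: {prim_res:.2e}")
print(f"objective: ADMM {tv_objective(x_hat, A, y, w, 0.1):.4f}, "
      f"least squares {tv_objective(x_ls, A, y, w, 0.1):.4f}")
```

Because \(\rho\) is fixed, the x-update matrix is the same at every iteration; with a matrix-free CT operator it cannot be formed explicitly, which is why the example solves that step approximately with CG.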

Set up and solve the weighted reconstruction problem

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| \mathbf{y} - A \mathbf{x} \|_W^2 + \lambda \| C \mathbf{x} \|_{2,1} \;,\]

where

\[W = \mathrm{diag}( \mathrm{counts} / I_0 ) \;.\]

The data fidelity term in this formulation follows equation (9) of [51], except for the scaling by \(I_0\), which keeps the data and regularization terms in balance when \(I_0\) changes.
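One way to read the balance argument: since \(\mathbb{E}[\mathrm{counts}] = I_0 \exp(-\alpha A \mathbf{x})\), the weights \(\mathrm{counts} / I_0\) fluctuate around \(\exp(-\alpha A \mathbf{x})\) whatever the value of \(I_0\). The curvature \(A^T W A\) of the data term, and hence a suitable \(\lambda\), therefore stays roughly the same when the flux changes, whereas unscaled weights \(W = \mathrm{diag}(\mathrm{counts})\) would grow in proportion to \(I_0\). The sketch below checks this numerically for a few hypothetical line integrals `p`.

```python
import numpy as np

alpha = 1e-2                             # attenuation coefficient, as in the example
p = np.array([0.0, 50.0, 100.0, 200.0])  # hypothetical noise-free line integrals
rng = np.random.default_rng(0)

mean_weights = {}
for I0 in (1e3, 1e5):
    # many realizations per ray so that the average weight is well estimated
    counts = rng.poisson(I0 * np.exp(-alpha * p), size=(10_000, p.size))
    weights = counts / I0                # diagonal of W for each realization
    mean_weights[I0] = weights.mean(axis=0)
    print(f"I0 = {I0:.0e}: mean weights {mean_weights[I0].round(3)}, "
          f"exp(-alpha p) {np.exp(-alpha * p).round(3)}")
```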

[8]:
lambda_weighted = 5e1
weights = snp.array(counts / Io)
f = loss.SquaredL2Loss(y=y, A=A, W=linop.Diagonal(weights))
admm_weighted = ADMM(
    f=f,
    g_list=[lambda_weighted * functional.L21Norm()],
    C_list=[linop.FiniteDifference(x_gt.shape, append=0)],
    rho_list=[ρ],
    maxiter=maxiter,
    x0=x0,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": cg_tol, "maxiter": cg_maxiter}),
    itstat_options={"display": True, "period": 10},
)
print()
admm_weighted.solve()
x_weighted = postprocess(admm_weighted.x)

Iter  Time      Objective  Prml Rsdl  Dual Rsdl  CG It  CG Res
-----------------------------------------------------------------
   0  1.64e+00  5.153e+06  4.966e+02  5.067e+04     10  9.466e-04
  10  1.20e+01  3.204e+06  5.621e+01  4.171e+04     10  1.360e-04
  20  2.20e+01  2.033e+06  5.652e+01  3.059e+04     10  1.623e-04
  30  3.33e+01  1.434e+06  4.340e+01  2.151e+04     10  1.313e-04
  40  4.62e+01  1.160e+06  3.656e+01  1.370e+04     10  1.152e-04
  50  5.88e+01  1.068e+06  2.327e+01  7.357e+03     10  6.871e-05
  60  7.15e+01  1.046e+06  1.174e+01  3.458e+03     10  3.205e-05
  70  8.44e+01  1.041e+06  6.115e+00  1.424e+03     10  1.415e-05
  80  9.71e+01  1.041e+06  2.780e+00  6.394e+02      9  8.665e-06
  90  1.07e+02  1.041e+06  1.761e+00  4.559e+02      7  8.477e-06
  99  1.15e+02  1.041e+06  1.375e+00  3.614e+02      6  8.812e-06

Show recovered images.

[9]:
def plot_recon(x, title, ax):
    """Plot an image with title indicating error metrics."""
    kplt.imview(
        x,
        cmap="Blues",
        title=f"{title}\nSNR: {metric.snr(x_gt, x):.2f} (dB), MAE: {metric.mae(x_gt, x):.3f}",
        ax=ax,
    )


fig, ax = kplt.subplots(nrows=2, ncols=2, sharex=True, sharey=True, figsize=(11, 10))
kplt.imview(x_gt, cmap="Blues", title="Ground truth", ax=ax[0, 0])
plot_recon(x0, "FBP Reconstruction", ax=ax[0, 1])
plot_recon(x_unweighted, "Unweighted TV Reconstruction", ax=ax[1, 0])
plot_recon(x_weighted, "Weighted TV Reconstruction", ax=ax[1, 1])
for ax_ in ax.ravel():
    ax_.set_xlim(64, 448)
    ax_.set_ylim(64, 448)
fig.subplots_adjust(left=0.1, right=0.99, top=0.95, bottom=0.05, wspace=0.2, hspace=0.01)
fig.colorbar(
    ax[0, 0].get_images()[0], ax=ax, location="right", shrink=0.9, pad=0.05, label="arbitrary units"
)
fig.show()
[Figure: ground truth and the FBP, unweighted TV, and weighted TV reconstructions, each reconstruction titled with its SNR (dB) and MAE.]
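The figure titles report SNR and MAE computed with `metric.snr` and `metric.mae`. For reference, here is a NumPy sketch of these two metrics. It assumes that SNR is defined as the variance of the reference divided by the mean squared error, expressed in dB; this is our understanding of the SCICO definition, so consult the `scico.metric` documentation before relying on it. The function names `snr_db` and `mae` are hypothetical.

```python
import numpy as np


def snr_db(reference, comparison):
    """Assumed definition: 10 log10(var(reference) / MSE), in dB."""
    mse = np.mean(np.abs(reference - comparison) ** 2)
    return 10.0 * np.log10(np.var(reference) / mse)


def mae(reference, comparison):
    """Mean absolute error."""
    return np.mean(np.abs(reference - comparison))


ref = np.array([0.0, 1.0, 2.0, 3.0])
est = np.array([0.1, 0.9, 2.1, 2.9])
print(f"SNR: {snr_db(ref, est):.2f} dB, MAE: {mae(ref, est):.3f}")
```

Under this definition a higher SNR is better, and because the reference variance sits in the numerator, SNR values are only comparable across reconstructions of the same ground truth, as in the figure.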