CT Training and Reconstructions with MoDL

This example demonstrates the training and application of the model-based deep learning (MoDL) architecture described in [1] to a CT reconstruction problem.

The source images are foam phantoms generated with xdesign.

A class scico.flax.MoDLNet implements the MoDL architecture, which solves the optimization problem

\[\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + \lambda \, \| \mathbf{x} - \mathrm{D}_w(\mathbf{x})\|_2^2 \;,\]

where \(A\) is a tomographic projector, \(\mathbf{y}\) is a set of sinograms, \(\mathrm{D}_w\) is the regularization (a denoiser), and \(\mathbf{x}\) is the set of reconstructed images. MoDL abstracts this iterative solution as an unrolled network, in which each iteration corresponds to a distinct stage of the network and updates the prediction by solving

\[\mathbf{x}^{k+1} = (A^T A + \lambda \, I)^{-1} (A^T \mathbf{y} + \lambda \, \mathbf{z}^k) \;,\]

via conjugate gradient. In the expression, \(k\) is the index of the stage (iteration), \(\mathbf{z}^k = \mathrm{ResNet}(\mathbf{x}^{k})\) is the regularization (a denoiser implemented as a residual convolutional neural network), \(\mathbf{x}^k\) is the output of the previous stage, \(\lambda > 0\) is a learned regularization parameter, and \(I\) is the identity operator. The output of the final stage is the set of reconstructed images.
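
The data consistency update above can be illustrated with a minimal, self-contained sketch of a single unrolled stage, solving the regularized normal equations via conjugate gradient. This is only a conceptual illustration, not the scico.flax.MoDLNet implementation: a small random matrix stands in for the tomographic projector and an identity map stands in for the learned denoiser.

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

m, n = 30, 20
key = jax.random.PRNGKey(0)
A_mat = jax.random.normal(key, (m, n)) / jnp.sqrt(m)  # stand-in for the projector A
y = A_mat @ jnp.ones(n)  # simulated measurements
lmbda = 5e-2  # regularization weight (learned in MoDLNet)

def denoiser(x):
    return x  # placeholder for the learned ResNet denoiser D_w

def modl_stage(x):
    z = denoiser(x)  # z^k = D_w(x^k)
    rhs = A_mat.T @ y + lmbda * z  # A^T y + lambda z^k
    lhs = lambda v: A_mat.T @ (A_mat @ v) + lmbda * v  # (A^T A + lambda I) v
    x_new, _ = cg(lhs, rhs, x0=x, maxiter=8)  # CG solve of the linear system
    return x_new

x = jnp.zeros(n)
for _ in range(3):  # three unrolled stages
    x = modl_stage(x)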

[1]:
import os
from functools import partial
from time import time

import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform2D
plot.config_notebook_plotting()

Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available).

[2]:
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.lib.xla_bridge.get_backend().platform
print("Platform: ", platform)
Platform:  gpu

Read data from cache or generate if not available.

[3]:
N = 256  # phantom size
train_nimg = 536  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45  # CT views

trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)
Data read from path       :   ~/.cache/scico/examples/data
Set --training--          :   Size:   536
Set --testing --          :   Size:   64
Data range --images  --   :    Min:   0.00, Max: 1.00
Data range --sinogram--   :    Min:   0.00, Max: 0.67
Data range --FBP     --   :    Min:   0.00, Max: 1.00

Build CT projection operator.

[4]:
angles = np.linspace(0, np.pi, n_projection)  # evenly spaced projection angles
A = XRayTransform2D(
    input_shape=(N, N),
    det_spacing=1,
    det_count=N,
    angles=angles,
)  # CT projection operator
A = (1.0 / N) * A  # normalized
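
As a quick check of the normalized operator (a sketch assuming the (num_images, N, N, 1) data layout used in the cells below), it can be applied to a single phantom to produce one sinogram; the scaling by 1/N keeps the sinogram values on a scale comparable to the images, consistent with the data ranges printed above.

x0 = trdt["img"][0, ..., 0]  # single (N, N) phantom
y0 = A @ x0  # sinogram, shape (n_projection, det_count)
print(y0.shape, y0.max())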

Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep the training and testing partitions separate.

[5]:
numtr = 100
numtt = 16
train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]}
test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]}

Define configuration dictionary for model and training loop.

Parameters have been selected for demonstration purposes and relatively short training. The model depth corresponds to the number of unrolled iterations in the MoDL model. The block depth controls the number of layers in each unrolled iteration. The number of filters is uniform throughout the iterations. The number of iterations used in the conjugate gradient (CG) solver can also be specified. Better performance may be obtained by increasing the depth, block depth, number of filters, CG iterations, or training epochs, at the cost of longer training times.

[6]:
# model configuration
model_conf = {
    "depth": 10,
    "num_filters": 64,
    "block_depth": 4,
    "cg_iter_1": 3,
    "cg_iter_2": 8,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 12345,
    "opt_type": "SGD",
    "momentum": 0.9,
    "batch_size": 16,
    "num_epochs": 20,
    "base_learning_rate": 1e-2,
    "warmup_epochs": 0,
    "log_every_steps": 40,
    "log": True,
    "checkpointing": True,
}

Construct functionality for ensuring that the learned regularization parameter is always positive.

[7]:
lmbdatrav = construct_traversal("lmbda")  # select lmbda parameters in model
lmbdapos = partial(
    clip_positive,  # apply this function
    traversal=lmbdatrav,  # to lmbda parameters in model
    minval=5e-4,
)
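
The effect of this post-processing step can be illustrated with a minimal pytree sketch (this is not the scico implementation, which applies the clipping through a flax parameter traversal): after each optimizer update, any leaf whose key path contains "lmbda" is pushed back up to a small positive floor.

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

def clip_lmbda(params, minval=5e-4):
    # Clip every parameter whose name contains "lmbda" to remain positive.
    flat = flatten_dict(params)
    flat = {
        k: (jnp.maximum(v, minval) if any("lmbda" in name for name in k) else v)
        for k, v in flat.items()
    }
    return unflatten_dict(flat)

params = {"ResNet_0": {"Conv_0": {"kernel": jnp.zeros((3, 3, 1, 64))}}, "lmbda": jnp.array([-0.1])}
print(clip_lmbda(params)["lmbda"])  # [0.0005]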

Print configuration of distributed run.

[8]:
print(f"\nJAX process: {jax.process_index()}{' / '}{jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}\n")

JAX process: 0 / 1
JAX local devices: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]

Check for a previously trained iterated model. If not found, construct a MoDLNet model using only one iteration (depth) and a few CG iterations for faster initialization. Run the first stage (initialization) training loop, followed by a second stage (depth iterations) training loop.

[9]:
channels = train_ds["image"].shape[-1]
workdir2 = os.path.join(
    os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out", "iterated"
)

stats_object_ini = None
stats_object = None

checkpoint_files = []
for dirpath, dirnames, filenames in os.walk(workdir2):
    checkpoint_files.extend(filenames)

if len(checkpoint_files) > 0:
    model = sflax.MoDLNet(
        operator=A,
        depth=model_conf["depth"],
        channels=channels,
        num_filters=model_conf["num_filters"],
        block_depth=model_conf["block_depth"],
        cg_iter=model_conf["cg_iter_2"],
    )

    train_conf["post_lst"] = [lmbdapos]
    # Parameters for 2nd stage
    train_conf["workdir"] = workdir2
    train_conf["opt_type"] = "ADAM"
    train_conf["num_epochs"] = 150
    # Construct training object
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
    )
    start_time = time()
    modvar, stats_object = trainer.train()
    time_train = time() - start_time
    time_init = 0.0
    epochs_init = 0
else:
    # One iteration (depth) in model and few CG iterations
    model = sflax.MoDLNet(
        operator=A,
        depth=1,
        channels=channels,
        num_filters=model_conf["num_filters"],
        block_depth=model_conf["block_depth"],
        cg_iter=model_conf["cg_iter_1"],
    )
    # First stage: initialization training loop.
    workdir1 = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
    train_conf["workdir"] = workdir1
    train_conf["post_lst"] = [lmbdapos]
    # Construct training object
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
    )

    start_time = time()
    modvar, stats_object_ini = trainer.train()
    time_init = time() - start_time
    epochs_init = train_conf["num_epochs"]

    print(
        f"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}"
        f"{'time[s]:':21s}{time_init:>7.2f}"
    )

    # Second stage: depth iterations training loop.
    model.depth = model_conf["depth"]
    model.cg_iter = model_conf["cg_iter_2"]
    train_conf["opt_type"] = "ADAM"
    train_conf["num_epochs"] = 150
    train_conf["workdir"] = workdir2
    # Construct training object, include current model parameters
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
        variables0=modvar,
    )
    start_time = time()
    modvar, stats_object = trainer.train()
    time_train = time() - start_time
channels: 1   training signals: 100   testing signals: 16   signal size: 256

Network Structure:
+------------------------------------------+----------------+--------+----------+--------+
| Name                                     | Shape          | Size   | Mean     | Std    |
+------------------------------------------+----------------+--------+----------+--------+
| ResNet_0/BatchNorm_0/bias                | (1,)           | 1      | 0.0      | 0.0    |
| ResNet_0/BatchNorm_0/scale               | (1,)           | 1      | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.0023   | 0.058  |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000156 | 0.0417 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153 | 0.0418 |
| ResNet_0/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.00192 | 0.0585 |
| lmbda                                    | (1,)           | 1      | 0.5      | 0.0    |
+------------------------------------------+----------------+--------+----------+--------+
Total weights: 75,267

Batch Normalization:
+-----------------------------------------+-------+------+------+-----+
| Name                                    | Shape | Size | Mean | Std |
+-----------------------------------------+-------+------+------+-----+
| ResNet_0/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ResNet_0/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
+-----------------------------------------+-------+------+------+-----+
Total weights: 386

Initial compilation, which might take some time ...
Initial compilation completed.

Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    6  1.52e+01  0.010000    0.093700      -0.59   0.040131      0.57
   13  2.59e+01  0.010000    0.018025       4.06   0.040379      0.54
   19  3.59e+01  0.010000    0.014100       5.12   0.040495      0.53
MoDLNet init      epochs:   20   time[s]:               37.20
channels: 1   training signals: 100   testing signals: 16   signal size: 256

Network Structure:
+------------------------------------------+----------------+--------+-----------+---------+
| Name                                     | Shape          | Size   | Mean      | Std     |
+------------------------------------------+----------------+--------+-----------+---------+
| ResNet_0/BatchNorm_0/bias                | (1,)           | 1      | 0.255     | 0.0     |
| ResNet_0/BatchNorm_0/scale               | (1,)           | 1      | 0.242     | 0.0     |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | -0.000882 | 0.0121  |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00539 |
| ResNet_0/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.000831  | 0.0676  |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 4.59e-05  | 0.00228 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00193 |
| ResNet_0/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000137  | 0.042   |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | -4.2e-05  | 0.00135 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00142 |
| ResNet_0/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000154  | 0.0418  |
| ResNet_0/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000495  | 0.0591  |
| lmbda                                    | (1,)           | 1      | 0.413     | 0.0     |
+------------------------------------------+----------------+--------+-----------+---------+
Total weights: 75,267

Batch Normalization:
+-----------------------------------------+-------+------+----------+----------+
| Name                                    | Shape | Size | Mean     | Std      |
+-----------------------------------------+-------+------+----------+----------+
| ResNet_0/BatchNorm_0/mean               | (1,)  | 1    | 0.0203   | 0.0      |
| ResNet_0/BatchNorm_0/var                | (1,)  | 1    | 2.31     | 0.0      |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.000316 | 0.0131   |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 0.299    | 2.98e-05 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0049   | 0.231    |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 0.876    | 3.24     |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.027    | 0.296    |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 0.605    | 0.277    |
+-----------------------------------------+-------+------+----------+----------+
Total weights: 386

Initial compilation, which might take some time ...
Initial compilation completed.

Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    6  2.34e+02  0.010000    0.211906      -3.28   0.049724     -0.36
   13  4.55e+02  0.010000    0.032641       1.46   0.030719      1.73
   19  6.71e+02  0.010000    0.026021       2.46   0.029391      1.92
   26  8.87e+02  0.010000    0.016752       4.41   0.012768      5.54
   33  1.10e+03  0.010000    0.009207       7.00   0.008295      7.42
   39  1.32e+03  0.010000    0.005996       8.85   0.005808      8.96
   46  1.54e+03  0.010000    0.004503      10.06   0.005296      9.36
   53  1.76e+03  0.010000    0.004123      10.45   0.004669      9.91
   59  1.97e+03  0.010000    0.003890      10.70   0.004452     10.12
   66  2.19e+03  0.010000    0.003713      10.90   0.004379     10.19
   73  2.41e+03  0.010000    0.003565      11.08   0.004277     10.29
   79  2.63e+03  0.010000    0.003403      11.28   0.004176     10.40
   86  2.84e+03  0.010000    0.003266      11.46   0.003938     10.65
   93  3.06e+03  0.010000    0.003160      11.60   0.003787     10.82
   99  3.28e+03  0.010000    0.003015      11.81   0.003691     10.93
  106  3.49e+03  0.010000    0.002991      11.84   0.003715     10.90
  113  3.71e+03  0.010000    0.002859      12.04   0.003571     11.08
  119  3.93e+03  0.010000    0.002776      12.16   0.003434     11.25
  126  4.14e+03  0.010000    0.002674      12.33   0.003446     11.23
  133  4.36e+03  0.010000    0.002603      12.44   0.003261     11.47
  139  4.58e+03  0.010000    0.002567      12.51   0.003236     11.50
  146  4.79e+03  0.010000    0.002517      12.59   0.003363     11.34

Evaluate on testing data.

[10]:
del train_ds["image"]
del train_ds["label"]

fmap = sflax.FlaxMap(model, modvar)
del model, modvar

maxn = numtt
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = np.clip(output, a_min=0, a_max=1.0)

Evaluate trained model in terms of reconstruction time and data fidelity.

[11]:
total_epochs = epochs_init + train_conf["num_epochs"]
total_time_train = time_init + time_train
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
    f"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}"
    f"{'time[s]:':10s}{total_time_train:>7.2f}"
)
print(
    f"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)
MoDLNet training  epochs:  170                     time[s]:  4940.85
MoDLNet testing   SNR: 11.44 dB   PSNR: 21.83 dB   time[s]:     9.66

Plot comparison.

[12]:
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Sinogram",
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="MoDLNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
../_images/examples_ct_astra_modl_train_foam2_23_0.png

Plot convergence statistics. Statistics are generated only if a training cycle was done (i.e., if the final-epoch results were not simply read from a checkpoint).

[13]:
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()

# Stats for initialization loop
if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
    hist = stats_object_ini.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        ptyp="semilogy",
        title="Loss function - Initialization",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        title="Metric - Initialization",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()
../_images/examples_ct_astra_modl_train_foam2_25_0.png
../_images/examples_ct_astra_modl_train_foam2_25_1.png