CT Training and Reconstructions with MoDL
This example demonstrates the training and application of a model-based deep learning (MoDL) architecture described in [1] applied to a CT reconstruction problem.
The source images are foam phantoms generated with xdesign.
A class scico.flax.MoDLNet implements the MoDL architecture, which solves the optimization problem
\[\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + \lambda \, \| \mathbf{x} - \mathrm{D}_w(\mathbf{x}) \|_2^2 \;,\]
where \(A\) is a tomographic projector, \(\mathbf{y}\) is a set of sinograms, \(\mathrm{D}_w\) is the regularizer (a denoiser), and \(\mathbf{x}\) is the set of reconstructed images. MoDL abstracts the iterative solution as an unrolled network in which each iteration corresponds to a different stage of the network and updates the prediction by solving
\[\mathbf{x}^{k+1} = (A^T A + \lambda \, I)^{-1} (A^T \mathbf{y} + \lambda \, \mathbf{z}^k) \;,\]
via conjugate gradient. In this expression, \(k\) is the index of the stage (iteration), \(\mathbf{z}^k = \mathrm{ResNet}(\mathbf{x}^{k})\) is the output of the regularizer (a denoiser implemented as a residual convolutional neural network), \(\mathbf{x}^k\) is the output of the previous stage, \(\lambda > 0\) is a learned regularization parameter, and \(I\) is the identity operator. The output of the final stage is the set of reconstructed images.
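To make the stage update concrete, the following NumPy sketch solves a linear system of the form \((A^T A + \lambda I)\,\mathbf{x} = A^T \mathbf{y} + \lambda \mathbf{z}\) with conjugate gradient and compares the result with a direct solve. It is an illustration only: the small dense random matrix stands in for the tomographic projector, the vector `z` stands in for the denoiser output, and `cg_solve` is a hypothetical helper, not the solver used inside scico.flax.MoDLNet.

```python
import numpy as np


def cg_solve(apply_H, b, x0, num_iter):
    """Conjugate gradient for H x = b, with H symmetric positive definite."""
    x = x0.copy()
    r = b - apply_H(x)  # initial residual
    p = r.copy()  # initial search direction
    rs = r @ r
    for _ in range(num_iter):
        Hp = apply_H(p)
        alpha = rs / (p @ Hp)  # step length along p
        x = x + alpha * p
        r = r - alpha * Hp
        rs_new = r @ r
        if rs_new < 1e-20:  # residual is numerically zero
            break
        p = r + (rs_new / rs) * p  # next conjugate direction
        rs = rs_new
    return x


rng = np.random.default_rng(0)
A = rng.standard_normal((12, 8))  # stand-in for the projector (assumption)
y = rng.standard_normal(12)  # stand-in for the measured sinogram
z = rng.standard_normal(8)  # stand-in for the denoiser output
lam = 0.5  # regularization parameter, must be positive


def apply_H(v):
    # Applies (A^T A + lam I) without forming the matrix explicitly
    return A.T @ (A @ v) + lam * v


b = A.T @ y + lam * z
x_cg = cg_solve(apply_H, b, np.zeros(8), num_iter=20)
x_direct = np.linalg.solve(A.T @ A + lam * np.eye(8), b)
print(np.max(np.abs(x_cg - x_direct)))
```

Because \(A^T A\) is positive semidefinite, any \(\lambda > 0\) makes the system matrix positive definite, which is the condition conjugate gradient relies on.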
[1]:
import os
from functools import partial
from time import time
import numpy as np
import jax
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform
plot.config_notebook_plotting()
Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available).
[2]:
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.default_backend()  # name of the default backend, e.g. "cpu" or "gpu"
print("Platform: ", platform)
Platform: gpu
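The assignment in the cell above replaces any value already stored in XLA_FLAGS. A hedged alternative, assuming other XLA options may already be set in the environment and should be preserved, appends the device-count flag instead. In either form the variable should be set before JAX is first used. The helper name `add_xla_flag` is hypothetical.

```python
import os


def add_xla_flag(flag):
    """Append `flag` to XLA_FLAGS unless it is already present."""
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


device_flag = "--xla_force_host_platform_device_count=8"
flags = add_xla_flag(device_flag)
print(flags)
```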
Read data from cache or generate if not available.
[3]:
N = 256 # phantom size
train_nimg = 536 # number of training images
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45 # CT views
trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)
Data read from path : ~/.cache/scico/examples/data
Set --training-- : Size: 536
Set --testing -- : Size: 64
Data range --images -- : Min: 0.00, Max: 1.00
Data range --sinogram-- : Min: 0.00, Max: 0.67
Data range --FBP -- : Min: 0.00, Max: 1.00
Build CT projection operator.
[4]:
angles = np.linspace(0, np.pi, n_projection) # evenly spaced projection angles
A = XRayTransform(
    input_shape=(N, N),
    detector_spacing=1,
    det_count=N,
    angles=angles,
)  # CT projection operator
A = (1.0 / N) * A # normalized
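A quick check of the view geometry may help when changing n_projection. Since np.linspace includes both end points by default, 45 views over \([0, \pi]\) are spaced by 180/44 degrees and the last view lies at 180 degrees. The sketch below only inspects the angles; it does not build a projector.

```python
import numpy as np

n_projection = 45
angles = np.linspace(0, np.pi, n_projection)  # same call as in the example

spacing_deg = np.rad2deg(np.diff(angles))  # gaps between consecutive views
print(spacing_deg[0])  # angular step in degrees
print(np.rad2deg(angles[-1]))  # final view angle in degrees
```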
Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep training and testing partitions.
[5]:
numtr = 100
numtt = 16
train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]}
test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]}
Define configuration dictionary for model and training loop.
Parameters have been selected for demonstration purposes and relatively short training. The model depth is the number of unrolled iterations in the MoDL model. The block depth controls the number of layers at each unrolled iteration. The number of filters is uniform throughout the iterations. The number of iterations used by the conjugate gradient (CG) solver can also be specified. Better performance may be obtained by increasing the depth, block depth, number of filters, CG iterations, or training epochs, at the cost of longer training times.
[6]:
# model configuration
model_conf = {
    "depth": 10,
    "num_filters": 64,
    "block_depth": 4,
    "cg_iter_1": 3,
    "cg_iter_2": 8,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 12345,
    "opt_type": "SGD",
    "momentum": 0.9,
    "batch_size": 16,
    "num_epochs": 20,
    "base_learning_rate": 1e-2,
    "warmup_epochs": 0,
    "log_every_steps": 40,
    "log": True,
    "checkpointing": True,
}
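The model configuration determines the number of trainable parameters. The sketch below reproduces the total of 75,267 reported in the parameter table printed during training. The layer layout is an assumption read off that table rather than taken from the scico source: block_depth - 1 convolution plus batch-normalization blocks with 3x3 kernels and no convolution bias, a final 3x3 convolution back to the input channels, one batch normalization on the output, and the scalar lmbda. The function names are hypothetical.

```python
def modl_param_count(channels, num_filters, block_depth, kernel=3):
    """Trainable parameters under the assumed layer layout."""
    k2 = kernel * kernel
    total = 0
    c_in = channels
    for _ in range(block_depth - 1):
        total += k2 * c_in * num_filters  # convolution kernel, no bias
        total += 2 * num_filters  # batch norm scale and bias
        c_in = num_filters
    total += k2 * c_in * channels  # final convolution
    total += 2 * channels  # output batch norm scale and bias
    total += 1  # learned regularization parameter lmbda
    return total


def modl_batch_stats_count(channels, num_filters, block_depth):
    """Running mean and variance entries of the batch norm layers."""
    return 2 * num_filters * (block_depth - 1) + 2 * channels


n_params = modl_param_count(channels=1, num_filters=64, block_depth=4)
n_stats = modl_batch_stats_count(channels=1, num_filters=64, block_depth=4)
print(n_params, n_stats)
```

The count does not involve depth: the tables printed for the one-stage and the ten-stage model list the same single ResNet_0 and the same total, which suggests that the denoiser weights are shared across the unrolled stages.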
Construct a function that keeps the learned regularization parameter positive by clipping it to a minimum value of 5e-4. It is passed to the trainer as a post-processing step.
[7]:
lmbdatrav = construct_traversal("lmbda") # select lmbda parameters in model
lmbdapos = partial(
    clip_positive,  # apply this function
    traversal=lmbdatrav,  # to lmbda parameters in model
    minval=5e-4,
)
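The effect of such a constraint can be sketched without scico. The example below is a simplified stand-in, not the implementation of clip_positive: it works on a plain dictionary of NumPy arrays, and `clip_min` is a hypothetical helper. It shows the same pattern of using functools.partial to fix the parameter name and the minimum value, so that the resulting function only needs the parameters as input.

```python
from functools import partial

import numpy as np


def clip_min(params, name, minval):
    """Return a copy of `params` with entry `name` clipped from below."""
    out = dict(params)
    out[name] = np.maximum(params[name], minval)
    return out


keep_lmbda_positive = partial(clip_min, name="lmbda", minval=5e-4)

# A gradient step may push lmbda below zero; other entries are left alone
params = {"lmbda": np.array([-0.2]), "kernel": np.array([-1.0, 2.0])}
clipped = keep_lmbda_positive(params)
print(clipped["lmbda"], clipped["kernel"])
```

Keeping the parameter strictly positive matters because the linear system solved at each stage is positive definite only for a positive regularization parameter.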
Print configuration of distributed run.
[8]:
print(f"JAX process: {jax.process_index()} / {jax.process_count()}")
print(f"JAX local devices: {jax.local_devices()}")
JAX process: 0 / 1
JAX local devices: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]
Check for a previously trained iterated model. If checkpoints are found, construct the full-depth MoDLNet model and continue with the second-stage training loop. If not, construct a MoDLNet model using only one iteration (depth) and few CG iterations for faster initialization, then run the first-stage (initialization) training loop followed by the second-stage (depth iterations) training loop.
[9]:
channels = train_ds["image"].shape[-1]
workdir2 = os.path.join(
    os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out", "iterated"
)
stats_object_ini = None
stats_object = None
checkpoint_files = []
for dirpath, dirnames, filenames in os.walk(workdir2):
    checkpoint_files = [fn for fn in filenames]
if len(checkpoint_files) > 0:
    model = sflax.MoDLNet(
        operator=A,
        depth=model_conf["depth"],
        channels=channels,
        num_filters=model_conf["num_filters"],
        block_depth=model_conf["block_depth"],
        cg_iter=model_conf["cg_iter_2"],
    )
    train_conf["post_lst"] = [lmbdapos]
    # Parameters for 2nd stage
    train_conf["workdir"] = workdir2
    train_conf["opt_type"] = "ADAM"
    train_conf["num_epochs"] = 150
    # Construct training object
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
    )
    start_time = time()
    modvar, stats_object = trainer.train()
    time_train = time() - start_time
    time_init = 0.0
    epochs_init = 0
else:
    # One iteration (depth) in model and few CG iterations
    model = sflax.MoDLNet(
        operator=A,
        depth=1,
        channels=channels,
        num_filters=model_conf["num_filters"],
        block_depth=model_conf["block_depth"],
        cg_iter=model_conf["cg_iter_1"],
    )
    # First stage: initialization training loop.
    workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "modl_ct_out")
    train_conf["workdir"] = workdir
    train_conf["post_lst"] = [lmbdapos]
    # Construct training object
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
    )
    start_time = time()
    modvar, stats_object_ini = trainer.train()
    time_init = time() - start_time
    epochs_init = train_conf["num_epochs"]
    print(
        f"{'MoDLNet init':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}{'':3s}"
        f"{'time[s]:':21s}{time_init:>7.2f}"
    )
    # Second stage: depth iterations training loop.
    model.depth = model_conf["depth"]
    model.cg_iter = model_conf["cg_iter_2"]
    train_conf["opt_type"] = "ADAM"
    train_conf["num_epochs"] = 150
    train_conf["workdir"] = workdir2
    # Construct training object, include current model parameters
    trainer = sflax.BasicFlaxTrainer(
        train_conf,
        model,
        train_ds,
        test_ds,
        variables0=modvar,
    )
    start_time = time()
    modvar, stats_object = trainer.train()
    time_train = time() - start_time
Channels: 1, training signals: 100, testing signals: 16, signal size: 256
+------------------------------------------+----------------+--------+----------+--------+
| Name                                     | Shape          | Size   | Mean     | Std    |
+------------------------------------------+----------------+--------+----------+--------+
| ResNet_0/BatchNorm_0/bias                | (1,)           | 1      | 0.0      | 0.0    |
| ResNet_0/BatchNorm_0/scale               | (1,)           | 1      | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.0023   | 0.058  |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000156 | 0.0417 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0      | 0.0    |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0      | 0.0    |
| ResNet_0/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153 | 0.0418 |
| ResNet_0/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.00192 | 0.0585 |
| lmbda                                    | (1,)           | 1      | 0.5      | 0.0    |
+------------------------------------------+----------------+--------+----------+--------+
Total: 75,267
+-----------------------------------------+-------+------+------+-----+
| Name                                    | Shape | Size | Mean | Std |
+-----------------------------------------+-------+------+------+-----+
| ResNet_0/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ResNet_0/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
+-----------------------------------------+-------+------+------+-----+
Total: 386
Initial compilation, this might take some minutes...
Initial compilation completed.
Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    6  1.93e+01  0.010000    0.093700      -0.59   0.040131      0.57
   13  3.31e+01  0.010000    0.018025       4.06   0.040379      0.54
   19  4.60e+01  0.010000    0.014087       5.12   0.040489      0.53
MoDLNet init epochs: 20 time[s]: 47.24
Channels: 1, training signals: 100, testing signals: 16, signal size: 256
+------------------------------------------+----------------+--------+-----------+---------+
| Name                                     | Shape          | Size   | Mean      | Std     |
+------------------------------------------+----------------+--------+-----------+---------+
| ResNet_0/BatchNorm_0/bias                | (1,)           | 1      | 0.255     | 0.0     |
| ResNet_0/BatchNorm_0/scale               | (1,)           | 1      | 0.242     | 0.0     |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | -0.000882 | 0.0121  |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00539 |
| ResNet_0/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.000832  | 0.0676  |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 4.6e-05   | 0.00228 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00193 |
| ResNet_0/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000137  | 0.042   |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | -4.19e-05 | 0.00135 |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.00141 |
| ResNet_0/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000154  | 0.0418  |
| ResNet_0/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000494  | 0.0591  |
| lmbda                                    | (1,)           | 1      | 0.413     | 0.0     |
+------------------------------------------+----------------+--------+-----------+---------+
Total: 75,267
+-----------------------------------------+-------+------+----------+----------+
| Name                                    | Shape | Size | Mean     | Std      |
+-----------------------------------------+-------+------+----------+----------+
| ResNet_0/BatchNorm_0/mean               | (1,)  | 1    | 0.0203   | 0.0      |
| ResNet_0/BatchNorm_0/var                | (1,)  | 1    | 2.31     | 0.0      |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.000317 | 0.0131   |
| ResNet_0/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 0.299    | 2.98e-05 |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0049   | 0.231    |
| ResNet_0/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 0.876    | 3.24     |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.027    | 0.296    |
| ResNet_0/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 0.605    | 0.277    |
+-----------------------------------------+-------+------+----------+----------+
Total: 386
Initial compilation, this might take some minutes...
Initial compilation completed.
Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    6  3.05e+02  0.010000    0.217458      -3.16   0.051299     -0.50
   13  5.97e+02  0.010000    0.033291       1.37   0.032434      1.49
   19  8.96e+02  0.010000    0.029017       1.97   0.026394      2.39
   26  1.19e+03  0.010000    0.021499       3.34   0.016957      4.31
   33  1.48e+03  0.010000    0.009802       6.77   0.008415      7.35
   39  1.77e+03  0.010000    0.005998       8.85   0.006215      8.67
   46  2.05e+03  0.010000    0.004526      10.04   0.005586      9.13
   53  2.34e+03  0.010000    0.004128      10.44   0.005214      9.43
   59  2.63e+03  0.010000    0.003889      10.70   0.004838      9.76
   66  2.92e+03  0.010000    0.003687      10.93   0.004537     10.04
   73  3.22e+03  0.010000    0.003510      11.14   0.004420     10.15
   79  3.51e+03  0.010000    0.003363      11.33   0.004397     10.17
   86  3.80e+03  0.010000    0.003225      11.52   0.004064     10.51
   93  4.09e+03  0.010000    0.003087      11.70   0.003745     10.87
   99  4.38e+03  0.010000    0.002979      11.86   0.003645     10.99
  106  4.68e+03  0.010000    0.002872      12.01   0.003694     10.93
  113  4.98e+03  0.010000    0.002779      12.16   0.003558     11.09
  119  5.27e+03  0.010000    0.002670      12.33   0.003424     11.26
  126  5.56e+03  0.010000    0.002594      12.46   0.003442     11.24
  133  5.85e+03  0.010000    0.002530      12.56   0.003408     11.28
  139  6.15e+03  0.010000    0.002493      12.64   0.003175     11.59
  146  6.45e+03  0.010000    0.002447      12.71   0.003118     11.67
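The epoch column in the logs above advances in uneven steps (6, 13, 19, ...) because logging is triggered every 40 training steps rather than once per epoch. The sketch below reproduces the logged epoch numbers under two assumptions inferred from the printed output, not from the trainer's code: each epoch consists of 100 // 16 = 6 steps (the incomplete final batch is dropped), and epochs are numbered from zero. The function name is hypothetical.

```python
train_size = 100  # numtr
batch_size = 16  # train_conf["batch_size"]
log_every_steps = 40  # train_conf["log_every_steps"]

steps_per_epoch = train_size // batch_size  # assumes the remainder is dropped


def logged_epochs(num_epochs):
    """Zero-based epoch index at every step where a log line is emitted."""
    total_steps = num_epochs * steps_per_epoch
    log_steps = range(log_every_steps, total_steps + 1, log_every_steps)
    return [(step - 1) // steps_per_epoch for step in log_steps]


first_stage = logged_epochs(20)  # initialization stage, 20 epochs
second_stage = logged_epochs(150)  # second stage, 150 epochs
print(first_stage)
print(second_stage[:4], second_stage[-1], len(second_stage))
```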
Evaluate on testing data.
[10]:
del train_ds["image"]
del train_ds["label"]
fmap = sflax.FlaxMap(model, modvar)
del model, modvar
maxn = numtt
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = np.clip(output, a_min=0, a_max=1.0)
Evaluate trained model in terms of reconstruction time and data fidelity.
[11]:
total_epochs = epochs_init + train_conf["num_epochs"]
total_time_train = time_init + time_train
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
    f"{'MoDLNet training':18s}{'epochs:':2s}{total_epochs:>5d}{'':21s}"
    f"{'time[s]:':10s}{total_time_train:>7.2f}"
)
print(
    f"{'MoDLNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)
MoDLNet training epochs: 170 time[s]: 6644.34
MoDLNet testing SNR: 11.52 dB PSNR: 21.90 dB time[s]: 12.45
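For readers comparing the two figures, the following NumPy sketch shows common definitions of SNR and PSNR in decibels. These are assumptions for illustration: SNR is taken as the ratio of the reference variance to the mean squared error, PSNR as the ratio of the squared signal range to the mean squared error, and the functions `snr_db` and `psnr_db` are hypothetical helpers. The implementations in scico.metric may differ in details such as the default signal range.

```python
import numpy as np


def snr_db(reference, estimate):
    """SNR in dB: reference variance over mean squared error."""
    mse = np.mean((reference - estimate) ** 2)
    return 10.0 * np.log10(np.var(reference) / mse)


def psnr_db(reference, estimate, signal_range=None):
    """PSNR in dB: squared signal range over mean squared error."""
    if signal_range is None:
        signal_range = reference.max() - reference.min()
    mse = np.mean((reference - estimate) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


reference = np.array([0.0, 1.0, 0.0, 1.0])
estimate = reference + 0.1  # constant error of 0.1 on every sample

snr_value = snr_db(reference, estimate)
psnr_value = psnr_db(reference, estimate)
print(f"SNR: {snr_value:.2f} dB, PSNR: {psnr_value:.2f} dB")
```

Under these definitions the gap between PSNR and SNR depends only on the reference signal (its squared range relative to its variance), not on the reconstruction, which is consistent with a fixed offset such as the 10.38 dB between the two values reported above.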
Plot comparison.
[12]:
np.random.seed(123)
indx = np.random.randint(0, high=maxn)
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Sinogram",
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="MoDLNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
Plot convergence statistics. Statistics are generated only if a training cycle was run (i.e., if the final epoch results were not simply read from a checkpoint).
[13]:
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()
# Stats for initialization loop
if stats_object_ini is not None and len(stats_object_ini.iterations) > 0:
    hist = stats_object_ini.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        ptyp="semilogy",
        title="Loss function - Initialization",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        title="Metric - Initialization",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()