CT Training and Reconstructions with ODP

This example demonstrates the training of the unrolled optimization with deep priors (ODP) gradient descent architecture described in [19] applied to a CT reconstruction problem.

The source images are foam phantoms generated with xdesign.

A class scico.flax.ODPNet implements the ODP architecture, which solves the optimization problem

\[\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + r(\mathbf{x}) \;,\]

where \(A\) is a tomographic projector, \(\mathbf{y}\) is a set of sinograms, \(r\) is a regularizer, and \(\mathbf{x}\) is the set of reconstructed images. The ODP gradient descent architecture abstracts the iterative solution as an unrolled network in which each iteration corresponds to a different stage of the ODP network and updates the prediction by solving

\[\mathbf{x}^{k+1} = \mathrm{argmin}_{\mathbf{x}} \; \alpha_k \| A \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} - \mathbf{x}^k - \mathbf{x}^{k+1/2} \|_2^2 \;,\]

which for the CT problem, using gradient descent, corresponds to

\[\mathbf{x}^{k+1} = \mathbf{x}^k + \mathbf{x}^{k+1/2} - \alpha_k \, A^T \, (A \mathbf{x}^k - \mathbf{y}) \;,\]

where \(k\) is the index of the stage (iteration), \(\mathbf{x}^k + \mathbf{x}^{k+1/2} = \mathrm{ResNet}(\mathbf{x}^{k})\) is the regularization (implemented as a residual convolutional neural network), \(\mathbf{x}^k\) is the output of the previous stage and \(\alpha_k > 0\) is a learned stage-wise parameter weighting the contribution of the fidelity term. The output of the final stage is the set of reconstructed images.
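
As a minimal sketch of the stage update above (not the scico.flax implementation), the following NumPy code applies a sequence of gradient descent stages. A small dense matrix stands in for the projector \(A\), and the hypothetical function `toy_resnet` stands in for the learned residual network. All names and the fixed \(\alpha_k\) values are illustrative assumptions.

```python
import numpy as np


def toy_resnet(x):
    """Hypothetical stand-in for the learned ResNet: returns x^k + x^{k+1/2}."""
    return x + 0.1 * np.tanh(x)  # identity plus a small residual correction


def odp_grad_desc_stage(x, y, A, alpha):
    """One unrolled stage: x^{k+1} = ResNet(x^k) - alpha * A^T (A x^k - y)."""
    fidelity_term = A.T @ (A @ x - y)  # proportional to the gradient of the fidelity term
    return toy_resnet(x) - alpha * fidelity_term


rng = np.random.default_rng(0)
A = rng.standard_normal((6, 4)) / 4.0  # toy forward operator (not a CT projector)
x_true = rng.standard_normal(4)
y = A @ x_true  # noiseless measurements

x = np.zeros(4)  # initial estimate
alphas = [0.1] * 8  # illustrative fixed weights; in ODPNet these are learned per stage
for alpha in alphas:
    x = odp_grad_desc_stage(x, y, A, alpha)
```

When the current estimate already matches the measurements, the fidelity term vanishes and the stage reduces to the ResNet output alone.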

[1]:
import os
from functools import partial
from time import time

import numpy as np

import jax

from mpl_toolkits.axes_grid1 import make_axes_locatable

from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_ct_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop.xray.astra import XRayTransform
plot.config_notebook_plotting()

Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available).

[2]:
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.default_backend()  # "cpu", "gpu", or "tpu"
print("Platform: ", platform)
Platform:  gpu
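
The `XLA_FLAGS` assignment above replaces any value already present in the environment, and the flag generally only takes effect if it is set before JAX initializes its backends. A minimal sketch of a helper that appends a flag instead of overwriting existing ones is shown below. The name `add_xla_flag` is hypothetical and not part of JAX or SCICO.

```python
import os


def add_xla_flag(flag, env=os.environ):
    """Append `flag` to XLA_FLAGS unless it is already present (hypothetical helper)."""
    current = env.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    env["XLA_FLAGS"] = " ".join(current)
    return env["XLA_FLAGS"]


# A plain dict is used here so the example does not modify the real environment.
demo_env = {"XLA_FLAGS": "--xla_dump_to=/tmp/xla"}
result = add_xla_flag("--xla_force_host_platform_device_count=8", env=demo_env)
```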

Read data from cache or generate if not available.

[3]:
N = 256  # phantom size
train_nimg = 536  # number of training images
test_nimg = 64  # number of testing images
nimg = train_nimg + test_nimg
n_projection = 45  # CT views

trdt, ttdt = load_ct_data(train_nimg, test_nimg, N, n_projection, verbose=True)
Data read from path       :   ~/.cache/scico/examples/data
Set --training--          :   Size:   536
Set --testing --          :   Size:   64
Data range --images  --   :    Min:   0.00, Max: 1.00
Data range --sinogram--   :    Min:   0.00, Max: 0.67
Data range --FBP     --   :    Min:   0.00, Max: 1.00

Build CT projection operator.

[4]:
angles = np.linspace(0, np.pi, n_projection)  # evenly spaced projection angles
A = XRayTransform(
    input_shape=(N, N),
    detector_spacing=1,
    det_count=N,
    angles=angles,
)  # CT projection operator
A = (1.0 / N) * A  # normalized
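
For reference, `np.linspace(0, np.pi, n_projection)` includes both endpoints by default, so the 45 views are spaced by \(\pi/44\) and the last view lies at exactly \(\pi\). The short check below only inspects the angle array; it does not construct the projector.

```python
import numpy as np

n_projection = 45  # CT views, as in the example
angles = np.linspace(0, np.pi, n_projection)  # both endpoints are included by default

step_deg = np.rad2deg(angles[1] - angles[0])  # angular spacing between adjacent views
print(f"first: {angles[0]:.3f} rad, last: {angles[-1]:.3f} rad, step: {step_deg:.2f} deg")
```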

Build training and testing structures. Inputs are the sinograms and outputs are the original generated foams. Keep training and testing partitions.

[5]:
numtr = 320
numtt = 32
train_ds = {"image": trdt["sino"][:numtr], "label": trdt["img"][:numtr]}
test_ds = {"image": ttdt["sino"][:numtt], "label": ttdt["img"][:numtt]}

Define configuration dictionary for model and training loop.

Parameters have been selected for demonstration purposes and relatively short training. The model depth is the number of unrolled stages (iterations) in the ODP network. The block depth controls the number of layers in the residual network at each stage. The number of filters is uniform across stages. Better performance may be obtained by increasing depth, block depth, number of filters, or training epochs, but may require longer training times.

[6]:
# model configuration
model_conf = {
    "depth": 8,
    "num_filters": 64,
    "block_depth": 6,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 1234,
    "opt_type": "ADAM",
    "batch_size": 16,
    "num_epochs": 200,
    "base_learning_rate": 1e-3,
    "warmup_epochs": 0,
    "log_every_steps": 160,
    "log": True,
    "checkpointing": True,
}
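
The logging interval is given in optimizer steps rather than epochs. As a quick worked check, assuming one step per batch and that all 320 training samples are used in every epoch, the configuration above gives 20 steps per epoch. Logging every 160 steps then corresponds to one log line every 8 epochs, which is consistent with the zero-based epoch indices 7, 15, 23, ... reported in the training log.

```python
num_train = 320  # numtr in the example
batch_size = 16
num_epochs = 200
log_every_steps = 160

steps_per_epoch = num_train // batch_size  # 20
total_steps = steps_per_epoch * num_epochs  # 4000
epochs_per_log = log_every_steps // steps_per_epoch  # 8
num_log_lines = total_steps // log_every_steps  # 25

# Zero-based epoch index at which each log line is emitted
logged_epochs = [i * epochs_per_log - 1 for i in range(1, num_log_lines + 1)]
print(logged_epochs[:3], logged_epochs[-1])
```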

Construct a post-processing function that ensures the learned fidelity weight parameter stays positive.

[7]:
alphatrav = construct_traversal("alpha")  # select alpha parameters in model
alphapost = partial(
    clip_positive,  # apply this function
    traversal=alphatrav,  # to alpha parameters in model
    minval=1e-3,
)
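
The effect of this post-processing step can be illustrated with a minimal sketch, assuming it amounts to clamping every selected parameter to at least `minval` after each update. The function `clip_alpha_min` and the flat parameter dictionary below are hypothetical simplifications, not the SCICO implementation, which operates on the model's parameter tree through a traversal.

```python
import numpy as np


def clip_alpha_min(params, minval=1e-3):
    """Return a new dict with entries whose name ends in 'alpha' clamped to >= minval."""
    return {
        name: np.maximum(value, minval) if name.endswith("alpha") else value
        for name, value in params.items()
    }


params = {
    "ODPGrDescBlock_0/alpha": np.array([-0.02]),  # drifted negative after an update
    "ODPGrDescBlock_1/alpha": np.array([0.005]),  # already valid, left unchanged
    "ODPGrDescBlock_0/resnet/Conv_0/kernel": np.array([-0.5, 0.3]),  # not an alpha
}
clipped = clip_alpha_min(params)
```

Only the `alpha` entries are affected; negative convolution weights are left as they are.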

Print configuration of distributed run.

[8]:
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
JAX process: 0 / 1
JAX local devices: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]

Construct ODPNet model.

[9]:
channels = train_ds["image"].shape[-1]
model = sflax.ODPNet(
    operator=A,
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
    block_depth=model_conf["block_depth"],
    odp_block=sflax.inverse.ODPGrDescBlock,
    alpha_ini=1e-2,
)
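
The parameter summary printed when training starts lists the initial `alpha` of each stage as 0.01, 0.005, 0.0025, and so on down to 7.81e-05. These values are consistent with `alpha_ini` being halved at each successive stage. The check below reproduces them under that assumption, which is inferred from the printed table rather than taken from the library documentation.

```python
alpha_ini = 1e-2
depth = 8

# Assumed pattern: the initial alpha for stage k is alpha_ini / 2**k
alphas = [alpha_ini / 2**k for k in range(depth)]
for k, a in enumerate(alphas):
    print(f"ODPGrDescBlock_{k}/alpha: {a:.3g}")
```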

Run training loop.

[10]:
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_ct_out")

train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapost]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
    train_conf,
    model,
    train_ds,
    test_ds,
)

start_time = time()
modvar, stats_object = trainer.train()
time_train = time() - start_time
Channels: 1, training signals: 320, testing signals: 32, signal size: 256
+---------------------------------------------------------+----------------+--------+-----------+--------+
| Name                                                    | Shape          | Size   | Mean      | Std    |
+---------------------------------------------------------+----------------+--------+-----------+--------+
| ODPGrDescBlock_0/alpha                                  | (1,)           | 1      | 0.01      | 0.0    |
| ODPGrDescBlock_0/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.000308 | 0.0568 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 2.38e-05  | 0.0416 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000402  | 0.0418 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000185  | 0.0416 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000101 | 0.0418 |
| ODPGrDescBlock_0/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.00276   | 0.058  |
| ODPGrDescBlock_1/alpha                                  | (1,)           | 1      | 0.005     | 0.0    |
| ODPGrDescBlock_1/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00128   | 0.0583 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000226 | 0.0419 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.00029   | 0.0415 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -5.14e-05 | 0.0417 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000319  | 0.0416 |
| ODPGrDescBlock_1/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000334  | 0.0583 |
| ODPGrDescBlock_2/alpha                                  | (1,)           | 1      | 0.0025    | 0.0    |
| ODPGrDescBlock_2/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.00119  | 0.0602 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000443  | 0.0415 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000163  | 0.0416 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.0004   | 0.0417 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 6.25e-05  | 0.0417 |
| ODPGrDescBlock_2/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.000515  | 0.0586 |
| ODPGrDescBlock_3/alpha                                  | (1,)           | 1      | 0.00125   | 0.0    |
| ODPGrDescBlock_3/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | -0.00179  | 0.057  |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -3.06e-05 | 0.0417 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153  | 0.0416 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000617  | 0.0418 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000218 | 0.0419 |
| ODPGrDescBlock_3/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.000309 | 0.0575 |
| ODPGrDescBlock_4/alpha                                  | (1,)           | 1      | 0.000625  | 0.0    |
| ODPGrDescBlock_4/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00173   | 0.0585 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000285 | 0.0417 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000136 | 0.0416 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000243  | 0.0417 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000183  | 0.0417 |
| ODPGrDescBlock_4/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.0033    | 0.0608 |
| ODPGrDescBlock_5/alpha                                  | (1,)           | 1      | 0.000312  | 0.0    |
| ODPGrDescBlock_5/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00178   | 0.0589 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000261 | 0.0417 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000299  | 0.0417 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000267 | 0.0418 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000261  | 0.0416 |
| ODPGrDescBlock_5/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | 0.00478   | 0.0608 |
| ODPGrDescBlock_6/alpha                                  | (1,)           | 1      | 0.000156  | 0.0    |
| ODPGrDescBlock_6/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.000588  | 0.0576 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000272  | 0.0416 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 6.59e-05  | 0.0417 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 7.39e-05  | 0.0418 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -5e-05    | 0.0418 |
| ODPGrDescBlock_6/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.000458 | 0.0583 |
| ODPGrDescBlock_7/alpha                                  | (1,)           | 1      | 7.81e-05  | 0.0    |
| ODPGrDescBlock_7/resnet/BatchNorm_0/bias                | (1,)           | 1      | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/BatchNorm_0/scale               | (1,)           | 1      | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/Conv_0/kernel     | (3, 3, 1, 64)  | 576    | 0.00317   | 0.0574 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -0.000238 | 0.0415 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000153  | 0.0418 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | -3.74e-05 | 0.0417 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/bias  | (64,)          | 64     | 0.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/scale | (64,)          | 64     | 1.0       | 0.0    |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/Conv_0/kernel     | (3, 3, 64, 64) | 36,864 | 0.000368  | 0.0418 |
| ODPGrDescBlock_7/resnet/Conv_0/kernel                   | (3, 3, 64, 1)  | 576    | -0.00299  | 0.0604 |
+---------------------------------------------------------+----------------+--------+-----------+--------+
Total: 1,194,008
+--------------------------------------------------------+-------+------+------+-----+
| Name                                                   | Shape | Size | Mean | Std |
+--------------------------------------------------------+-------+------+------+-----+
| ODPGrDescBlock_0/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_0/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_1/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_2/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_3/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_4/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_5/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_6/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/BatchNorm_0/mean               | (1,)  | 1    | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/BatchNorm_0/var                | (1,)  | 1    | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_0/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_1/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_2/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_3/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/mean | (64,) | 64   | 0.0  | 0.0 |
| ODPGrDescBlock_7/resnet/ConvBNBlock_4/BatchNorm_0/var  | (64,) | 64   | 1.0  | 0.0 |
+--------------------------------------------------------+-------+------+------+-----+
Total: 5,136
Initial compilation, this might take some minutes...
Initial compilation completed.
Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    7  1.61e+02  0.001000    0.052377       3.85   0.132320     -4.61
   15  3.06e+02  0.001000    0.007551       7.91   0.126599     -4.42
   23  4.47e+02  0.001000    0.005597       9.15   0.086118     -2.74
   31  5.89e+02  0.001000    0.004938       9.69   0.081113     -2.48
   39  7.30e+02  0.001000    0.004495      10.10   0.028649      2.04
   47  8.73e+02  0.001000    0.004167      10.42   0.050053     -0.39
   55  1.02e+03  0.001000    0.003868      10.75   0.011462      6.03
   63  1.16e+03  0.001000    0.003687      10.96   0.004091     10.50
   71  1.30e+03  0.001000    0.006230      10.08  21614.328125    -56.68
   79  1.44e+03  0.001000    0.012242       6.12   0.019623      3.68
   87  1.59e+03  0.001000    0.005031       9.62   0.041972      0.38
   95  1.73e+03  0.001000    0.004204      10.39   0.012248      5.73
  103  1.87e+03  0.001000    0.003774      10.86   0.005177      9.47
  111  2.01e+03  0.001000    0.003523      11.15   0.004176     10.40
  119  2.16e+03  0.001000    0.003324      11.41   0.003917     10.68
  127  2.30e+03  0.001000    0.003194      11.58   0.004408     10.17
  135  2.44e+03  0.001000    0.003065      11.76   0.003885     10.72
  143  2.59e+03  0.001000    0.002968      11.90   0.004074     10.51
  151  2.73e+03  0.001000    0.002864      12.05   0.003808     10.80
  159  2.87e+03  0.001000    0.002801      12.15   0.003001     11.84
  167  3.02e+03  0.001000    0.002711      12.29   0.002757     12.21
  175  3.16e+03  0.001000    0.002627      12.43   0.002760     12.21
  183  3.30e+03  0.001000    0.002596      12.48   0.004782      9.84
  191  3.45e+03  0.001000    0.002539      12.58   0.003661     10.98
  199  3.59e+03  0.001000    0.002479      12.68   0.002647     12.39
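
The evaluation loss in the log above is not monotone: it spikes at epoch 71 before recovering. As a minimal sketch of how such a log could be inspected, the snippet below parses a few rows copied from the output above and locates the best and worst evaluation epochs. The `log_text` string and the `parse_log` helper are illustrative assumptions and are not part of SCICO.

```python
import numpy as np

# A few rows copied verbatim from the training log above.
# Columns: Epoch, Time, Train_LR, Train_Loss, Train_SNR, Eval_Loss, Eval_SNR
log_text = """
    63  1.16e+03  0.001000    0.003687      10.96   0.004091     10.50
    71  1.30e+03  0.001000    0.006230      10.08  21614.328125    -56.68
    79  1.44e+03  0.001000    0.012242       6.12   0.019623      3.68
   199  3.59e+03  0.001000    0.002479      12.68   0.002647     12.39
"""


def parse_log(text):
    """Hypothetical helper: parse whitespace-separated log rows into a float array."""
    rows = [line.split() for line in text.strip().splitlines()]
    return np.array(rows, dtype=float)


log = parse_log(log_text)
epoch, eval_loss, eval_snr = log[:, 0], log[:, 5], log[:, 6]

# Index of the row with the highest evaluation SNR and of the largest evaluation loss.
best = int(np.argmax(eval_snr))
worst = int(np.argmax(eval_loss))
print(f"Best eval SNR {eval_snr[best]:.2f} dB at epoch {int(epoch[best])}")
print(f"Largest eval loss {eval_loss[worst]:.3e} at epoch {int(epoch[worst])}")
```

Splitting on whitespace rather than on fixed column widths keeps the parsing robust to rows such as the one for epoch 71, where the loss value overflows its column.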

Evaluate the trained model on the testing data.

[11]:
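# Free the training data, which is no longer needed.
del train_ds["image"]
del train_ds["label"]

# Wrap the trained model and its variables as a callable map.
fmap = sflax.FlaxMap(model, modvar)
del model, modvar

# Reconstruct the test sinograms and time the evaluation.
maxn = numtt
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
# Clip the reconstructions to the range [0, 1].
output = np.clip(output, a_min=0, a_max=1.0)
epochs = train_conf["num_epochs"]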

Evaluate the trained model in terms of reconstruction time and data fidelity.

[12]:
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
    f"{'ODPNet training':18s}{'epochs:':2s}{epochs:>5d}{'':21s}"
    f"{'time[s]:':10s}{time_train:>7.2f}"
)
print(
    f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)
ODPNet training   epochs:  200                     time[s]:  3590.61
ODPNet testing    SNR: 12.40 dB   PSNR: 22.78 dB   time[s]:    11.57
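
For reference, the two data fidelity metrics reported above can be sketched in plain NumPy. The definitions below are a common convention (SNR as the ratio of reference energy to error energy, PSNR as the ratio of squared signal range to mean squared error) and are an assumption here; `scico.metric.snr` and `scico.metric.psnr` may differ in detail. The function names `snr_db` and `psnr_db` and the synthetic data are illustrative only.

```python
import numpy as np


def snr_db(reference, comparison):
    """SNR in dB, assuming the energy-ratio definition."""
    signal = np.sum(np.abs(reference) ** 2)
    noise = np.sum(np.abs(reference - comparison) ** 2)
    return 10.0 * np.log10(signal / noise)


def psnr_db(reference, comparison, signal_range=None):
    """PSNR in dB; the range defaults to max - min of the reference (assumption)."""
    if signal_range is None:
        signal_range = np.abs(reference.max() - reference.min())
    mse = np.mean(np.abs(reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


# Synthetic stand-ins for the test labels and the network output.
rng = np.random.default_rng(0)
reference = rng.uniform(0.0, 1.0, size=(4, 32, 32, 1))
comparison = np.clip(reference + 0.05 * rng.standard_normal(reference.shape), 0.0, 1.0)

snr_val = snr_db(reference, comparison)
psnr_val = psnr_db(reference, comparison)
print(f"SNR: {snr_val:.2f} dB   PSNR: {psnr_val:.2f} dB")

# Under these definitions, the gap between the two metrics depends only on the reference.
signal_range = reference.max() - reference.min()
offset = 10.0 * np.log10(signal_range**2 / np.mean(reference**2))
print(f"PSNR - SNR: {psnr_val - snr_val:.2f} dB   (reference-only offset: {offset:.2f} dB)")
```

Under these assumed definitions, PSNR and SNR share the same error term, so their difference is fixed by the reference images alone. This is why the two values move together across models evaluated on the same test set.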

Plot comparison.

[13]:
np.random.seed(123)
indx = np.random.randint(0, high=maxn)

fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Sinogram",
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="ODPNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f (dB)"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
[Figure: ground truth, sinogram, and ODPNet reconstruction for a randomly selected test image.]

Plot convergence statistics. Statistics are available only if a training cycle was actually run (i.e. if the final epoch results were not read from a checkpoint).

[14]:
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()
[Figure: training and test loss (left, semilog scale) and SNR in dB (right) versus epoch.]