Deconvolution Training and Reconstructions with ODP
This example demonstrates the training and application of the unrolled optimization with deep priors (ODP) with proximal map architecture described in [19] for a deconvolution (deblurring) problem.
The source images are foam phantoms generated with xdesign.
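The class scico.flax.ODPNet implements the ODP architecture, which solves the optimization problem
\[\mathrm{argmin}_{\mathbf{x}} \; \| A \mathbf{x} - \mathbf{y} \|_2^2 + r(\mathbf{x}) \;,\]
where \(A\) is a circular convolution, \(\mathbf{y}\) is a set of blurred images, \(r\) is a regularizer and \(\mathbf{x}\) is the set of deblurred images. The ODP proximal map architecture abstracts the iterative solution by an unrolled network in which each iteration corresponds to a different stage in the ODP network and updates the prediction by solving
\[\mathbf{x}^{k+1} = \mathrm{argmin}_{\mathbf{x}} \; \frac{\alpha_k}{2} \| A \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} - \mathbf{x}^k - \mathbf{x}^{k+1/2} \|_2^2 \;,\]
which for the deconvolution problem corresponds to
\[\mathbf{x}^{k+1} = \mathcal{F}^{-1} \mathrm{diag} (\alpha_k | \mathcal{K}|^2 + 1 )^{-1} \mathcal{F} \, (\alpha_k K^T \ast \mathbf{y} + \mathbf{x}^k + \mathbf{x}^{k+1/2}) \;,\]
where \(k\) is the index of the stage (iteration), \(\mathbf{x}^k + \mathbf{x}^{k+1/2} = \mathrm{ResNet}(\mathbf{x}^{k})\) is the regularization (implemented as a residual convolutional neural network), \(\mathbf{x}^k\) is the output of the previous stage, \(\alpha_k > 0\) is a learned stage-wise parameter weighting the contribution of the fidelity term, \(\mathcal{F}\) is the DFT, \(K\) is the blur kernel, \(K^T \ast\) denotes application of the adjoint of the blur (correlation with the kernel), and \(\mathcal{K}\) is the DFT of \(K\). The output of the final stage is the set of deblurred images.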
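The Fourier-domain stage update can be checked numerically. The following NumPy sketch is an illustration only and does not use the SCICO implementation: it assumes the stage subproblem is \(\frac{\alpha}{2} \| A \mathbf{x} - \mathbf{y} \|_2^2 + \frac{1}{2} \| \mathbf{x} - \mathbf{z} \|_2^2\) with \(\mathbf{z} = \mathbf{x}^k + \mathbf{x}^{k+1/2}\), uses random arrays as stand-ins for the blurred image and the ResNet output, and verifies that the closed-form solution satisfies the normal equations \((\alpha A^T A + I)\mathbf{x} = \alpha A^T \mathbf{y} + \mathbf{z}\). The helper names blur and blur_adj are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
N = 16                          # small image size, enough for the check
psf = np.ones((3, 3)) / 9.0     # box blur kernel K
alpha = 0.5                     # illustrative fidelity weight (assumed value)

# DFT of the kernel, zero-padded to the image size
Khat = np.fft.fft2(psf, s=(N, N))


def blur(x):
    """Circular convolution A x, computed in the Fourier domain."""
    return np.real(np.fft.ifft2(Khat * np.fft.fft2(x)))


def blur_adj(x):
    """Adjoint A^T x, i.e. circular correlation with the kernel."""
    return np.real(np.fft.ifft2(np.conj(Khat) * np.fft.fft2(x)))


y = rng.random((N, N))          # stand-in for a blurred image
z = rng.random((N, N))          # stand-in for x^k + x^{k+1/2}

# Closed-form stage update: divide by (alpha |Khat|^2 + 1) in the Fourier domain
rhs = alpha * blur_adj(y) + z
x_next = np.real(np.fft.ifft2(np.fft.fft2(rhs) / (alpha * np.abs(Khat) ** 2 + 1.0)))

# Residual of the normal equations; should be at floating-point round-off level
residual = alpha * blur_adj(blur(x_next)) + x_next - rhs
print(np.max(np.abs(residual)))
```

Because a circular convolution is diagonalized by the DFT, the linear solve reduces to an element-wise division, which is why each stage is cheap to evaluate and to differentiate through.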
[1]:
import os
from functools import partial
from time import time
import numpy as np
import jax
from mpl_toolkits.axes_grid1 import make_axes_locatable
from scico import flax as sflax
from scico import metric, plot
from scico.flax.examples import load_foam1_blur_data
from scico.flax.train.traversals import clip_positive, construct_traversal
from scico.linop import CircularConvolve
plot.config_notebook_plotting()
Prepare parallel processing. Set an arbitrary processor count (only applies if GPU is not available).
[2]:
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
platform = jax.default_backend()  # public API; jax.lib.xla_bridge is not available in recent JAX releases
print("Platform: ", platform)
Platform: gpu
Define blur operator.
[3]:
output_size = 256 # patch size
n = 3 # convolution kernel size
σ = 20.0 / 255 # noise level
psf = np.ones((n, n)) / (n * n) # blur kernel
ishape = (output_size, output_size)
opBlur = CircularConvolve(h=psf, input_shape=ishape)
opBlur_vmap = jax.vmap(opBlur) # for batch processing in data generation
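To make the action of a circular blur concrete, the sketch below applies the same 3x3 box kernel to a single bright pixel using plain NumPy FFTs. It is not the SCICO CircularConvolve implementation; in particular, placing the kernel origin at index (0, 0) is an assumption of this sketch. The impulse is placed in a corner so that the periodic wrap-around is visible.

```python
import numpy as np

n = 3
psf = np.ones((n, n)) / (n * n)   # box blur kernel, as defined above

x = np.zeros((8, 8))
x[7, 7] = 1.0                     # single bright pixel in the last row and column

# Circular convolution: multiply the DFTs, with the kernel zero-padded to the image size
x_blur = np.real(np.fft.ifft2(np.fft.fft2(psf, s=x.shape) * np.fft.fft2(x)))

# The impulse spreads over a 3x3 patch that wraps around the image boundary
patch = x_blur[np.ix_([7, 0, 1], [7, 0, 1])]
print(np.round(patch * 9, 6))     # every entry of the patch equals 1/9
print(round(float(x_blur.sum()), 6))  # total intensity is preserved
```

The wrap-around is what distinguishes a circular convolution from a zero-padded one, and it is the property that allows the blur to be inverted with element-wise operations in the Fourier domain.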
Read data from cache or generate if not available.
[4]:
train_nimg = 416 # number of training images
test_nimg = 64 # number of testing images
nimg = train_nimg + test_nimg
train_ds, test_ds = load_foam1_blur_data(
    train_nimg,
    test_nimg,
    output_size,
    psf,
    σ,
    verbose=True,
)
Data read from path : ~/.cache/scico/examples/data
Set --training-- : Size: 416
Set --testing -- : Size: 64
Data range -- images -- : Min: 0.00, Max: 1.00
Data range -- labels -- : Min: 0.00, Max: 1.00
Define configuration dictionary for model and training loop.
Parameters have been selected for demonstration purposes and relatively short training. The model depth corresponds to the number of unrolled iterations (stages) in the ODP model. The block depth controls the number of layers in each stage. The number of filters is the same in every stage. Better performance may be obtained by increasing the depth, block depth, number of filters, or number of training epochs, at the cost of longer training times.
[5]:
# model configuration
model_conf = {
    "depth": 2,
    "num_filters": 64,
    "block_depth": 3,
}
# training configuration
train_conf: sflax.ConfigDict = {
    "seed": 0,
    "opt_type": "SGD",
    "momentum": 0.9,
    "batch_size": 16,
    "num_epochs": 50,
    "base_learning_rate": 1e-2,
    "warmup_epochs": 0,
    "log_every_steps": 100,
    "log": True,
    "checkpointing": True,
}
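The effect of depth, block depth, and number of filters on model size can be estimated with a short calculation. The helper below, odp_param_count, is hypothetical and not part of SCICO. Its layer layout is inferred from the parameter summary printed by the training loop in this example (for a block depth of 3: two convolution plus batch-norm blocks followed by a final convolution and batch norm, plus one alpha per stage, with 3x3 kernels and no convolution biases); extending that layout to other block depths is an assumption.

```python
def odp_param_count(depth, num_filters, block_depth, channels=1, ksz=3):
    """Trainable-parameter count under the layer layout assumed in the text."""
    conv_in = ksz * ksz * channels * num_filters       # first conv: channels -> filters
    conv_mid = ksz * ksz * num_filters * num_filters   # inner convs: filters -> filters
    conv_out = ksz * ksz * num_filters * channels      # last conv: filters -> channels
    bn_filters = 2 * num_filters                       # batch-norm scale and bias
    bn_out = 2 * channels                              # batch norm on the block output
    per_stage = (
        1                                              # learned fidelity weight alpha
        + conv_in + bn_filters
        + (block_depth - 2) * (conv_mid + bn_filters)
        + conv_out + bn_out
    )
    return depth * per_stage


print(odp_param_count(depth=2, num_filters=64, block_depth=3))  # 76550
```

For the configuration used here the result agrees with the total of 76,550 reported in the parameter summary. The inner convolutions dominate, so the count grows roughly quadratically with the number of filters and linearly with depth and block depth.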
Construct ODPNet model.
[6]:
channels = train_ds["image"].shape[-1]
model = sflax.ODPNet(
    operator=opBlur,
    depth=model_conf["depth"],
    channels=channels,
    num_filters=model_conf["num_filters"],
    block_depth=model_conf["block_depth"],
    odp_block=sflax.inverse.ODPProxDcnvBlock,
)
Construct a post-processing function that keeps the learned fidelity weight parameter positive by clipping it from below at minval.
[7]:
alphatrav = construct_traversal("alpha") # select alpha parameters in model
alphapos = partial(
    clip_positive,  # apply this function
    traversal=alphatrav,  # to alpha parameters in model
    minval=1e-3,
)
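The idea behind this post-processing step can be illustrated without Flax. The sketch below is a stand-in rather than the SCICO clip_positive or construct_traversal code: clip_alpha is a hypothetical helper that walks a nested parameter dictionary and clips every leaf named "alpha" from below, leaving all other parameters untouched. The toy parameter values are made up for the illustration.

```python
import numpy as np


def clip_alpha(params, minval=1e-3):
    """Return a copy of a nested parameter dict with every 'alpha' leaf clipped from below."""
    out = {}
    for name, value in params.items():
        if isinstance(value, dict):
            out[name] = clip_alpha(value, minval)   # recurse into sub-modules
        elif name == "alpha":
            out[name] = np.maximum(value, minval)   # enforce alpha >= minval
        else:
            out[name] = value                       # other parameters pass through
    return out


# Toy parameter tree with made-up values
params = {
    "ODPProxDcnvBlock_0": {"alpha": np.array([-0.2]), "resnet": {"kernel": np.array([-1.0, 2.0])}},
    "ODPProxDcnvBlock_1": {"alpha": np.array([0.25]), "resnet": {"kernel": np.array([0.5, -0.5])}},
}
clipped = clip_alpha(params)
print(clipped["ODPProxDcnvBlock_0"]["alpha"])  # [0.001]
print(clipped["ODPProxDcnvBlock_1"]["alpha"])  # [0.25]
```

Keeping alpha strictly positive matters because the stage update divides by \(\alpha_k |\mathcal{K}|^2 + 1\) and interprets \(\alpha_k\) as a weight on the data fidelity term.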
Run training loop.
[8]:
workdir = os.path.join(os.path.expanduser("~"), ".cache", "scico", "examples", "odp_dcnv_out")
print(f"{'JAX process: '}{jax.process_index()}{' / '}{jax.process_count()}")
print(f"{'JAX local devices: '}{jax.local_devices()}")
train_conf["workdir"] = workdir
train_conf["post_lst"] = [alphapos]
# Construct training object
trainer = sflax.BasicFlaxTrainer(
    train_conf,
    model,
    train_ds,
    test_ds,
)
start_time = time()
modvar, stats_object = trainer.train()
time_train = time() - start_time
JAX process: 0 / 1
JAX local devices: [cuda(id=0), cuda(id=1), cuda(id=2), cuda(id=3), cuda(id=4), cuda(id=5), cuda(id=6), cuda(id=7)]
Channels: 1, training signals: 416, testing signals: 64, signal size: 256
+-----------------------------------------------------------+----------------+--------+-----------+--------+
| Name | Shape | Size | Mean | Std |
+-----------------------------------------------------------+----------------+--------+-----------+--------+
| ODPProxDcnvBlock_0/alpha | (1,) | 1 | 0.5 | 0.0 |
| ODPProxDcnvBlock_0/resnet/BatchNorm_0/bias | (1,) | 1 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/BatchNorm_0/scale | (1,) | 1 | 1.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/bias | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_0/Conv_0/kernel | (3, 3, 1, 64) | 576 | -0.00178 | 0.0562 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/bias | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_1/Conv_0/kernel | (3, 3, 64, 64) | 36,864 | -0.000448 | 0.0418 |
| ODPProxDcnvBlock_0/resnet/Conv_0/kernel | (3, 3, 64, 1) | 576 | 0.00288 | 0.0593 |
| ODPProxDcnvBlock_1/alpha | (1,) | 1 | 0.25 | 0.0 |
| ODPProxDcnvBlock_1/resnet/BatchNorm_0/bias | (1,) | 1 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/BatchNorm_0/scale | (1,) | 1 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/bias | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/scale | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_0/Conv_0/kernel | (3, 3, 1, 64) | 576 | -0.00168 | 0.0563 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/bias | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/scale | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_1/Conv_0/kernel | (3, 3, 64, 64) | 36,864 | 7.01e-05 | 0.0415 |
| ODPProxDcnvBlock_1/resnet/Conv_0/kernel | (3, 3, 64, 1) | 576 | -1.25e-05 | 0.06 |
+-----------------------------------------------------------+----------------+--------+-----------+--------+
Total: 76,550
+----------------------------------------------------------+-------+------+------+-----+
| Name | Shape | Size | Mean | Std |
+----------------------------------------------------------+-------+------+------+-----+
| ODPProxDcnvBlock_0/resnet/BatchNorm_0/mean | (1,) | 1 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/BatchNorm_0/var | (1,) | 1 | 1.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_0/BatchNorm_0/var | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_0/resnet/ConvBNBlock_1/BatchNorm_0/var | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/BatchNorm_0/mean | (1,) | 1 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/BatchNorm_0/var | (1,) | 1 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/mean | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_0/BatchNorm_0/var | (64,) | 64 | 1.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/mean | (64,) | 64 | 0.0 | 0.0 |
| ODPProxDcnvBlock_1/resnet/ConvBNBlock_1/BatchNorm_0/var | (64,) | 64 | 1.0 | 0.0 |
+----------------------------------------------------------+-------+------+------+-----+
Total: 516
Initial compilation, this might take some minutes...
Initial compilation completed.
Epoch  Time      Train_LR  Train_Loss  Train_SNR  Eval_Loss  Eval_SNR
---------------------------------------------------------------------
    3  5.05e+01  0.010000    0.083444       9.13   0.047218      3.18
    7  9.48e+01  0.010000    0.005381      12.59   0.067449      1.63
   11  9.85e+01  0.010000    0.004385      13.47   0.099316     -0.05
   15  1.02e+02  0.010000    0.003794      14.10   0.122478     -0.96
   19  1.05e+02  0.010000    0.003404      14.57   0.105635     -0.31
   23  1.08e+02  0.010000    0.003132      14.93   0.050777      2.87
   26  1.12e+02  0.010000    0.002935      15.21   0.018653      7.22
   30  1.15e+02  0.010000    0.002788      15.43   0.007561     11.14
   34  1.19e+02  0.010000    0.002673      15.62   0.004310     13.58
   38  1.22e+02  0.010000    0.002580      15.77   0.003203     14.87
   42  1.25e+02  0.010000    0.002504      15.90   0.002743     15.54
   46  1.29e+02  0.010000    0.002440      16.01   0.002598     15.78
   49  1.32e+02  0.010000    0.002385      16.11   0.002466     16.00
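The irregular spacing of the epochs in the training log follows from the configuration. The short calculation below assumes (this is an inference from the log, not a documented property of the trainer) that a row is written every log_every_steps training steps, with steps counted from 1, and that the reported value is the zero-based epoch containing that step.

```python
train_nimg = 416
batch_size = 16
num_epochs = 50
log_every_steps = 100

steps_per_epoch = train_nimg // batch_size   # 26 steps per epoch
total_steps = steps_per_epoch * num_epochs   # 1300 steps in total

# Zero-based epoch containing each logged step
logged_epochs = [
    (step - 1) // steps_per_epoch
    for step in range(log_every_steps, total_steps + 1, log_every_steps)
]
print(steps_per_epoch, total_steps)  # 26 1300
print(logged_epochs)  # [3, 7, 11, 15, 19, 23, 26, 30, 34, 38, 42, 46, 49]
```

Under these assumptions the 13 logged epochs match the rows of the log, so a smaller log_every_steps would give a finer-grained view of the evaluation loss during the early epochs.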
Evaluate on testing data.
[9]:
del train_ds["image"]
del train_ds["label"]
fmap = sflax.FlaxMap(model, modvar)
del model, modvar
maxn = test_nimg // 4
start_time = time()
output = fmap(test_ds["image"][:maxn])
time_eval = time() - start_time
output = np.clip(output, a_min=0, a_max=1.0)
Evaluate trained model in terms of reconstruction time and data fidelity.
[10]:
snr_eval = metric.snr(test_ds["label"][:maxn], output)
psnr_eval = metric.psnr(test_ds["label"][:maxn], output)
print(
    f"{'ODPNet training':18s}{'epochs:':2s}{train_conf['num_epochs']:>5d}"
    f"{'':21s}{'time[s]:':10s}{time_train:>7.2f}"
)
print(
    f"{'ODPNet testing':18s}{'SNR:':5s}{snr_eval:>5.2f}{' dB'}{'':3s}"
    f"{'PSNR:':6s}{psnr_eval:>5.2f}{' dB'}{'':3s}{'time[s]:':10s}{time_eval:>7.2f}"
)
ODPNet training epochs: 50 time[s]: 133.18
ODPNet testing SNR: 16.39 dB PSNR: 23.47 dB time[s]: 7.53
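For readers unfamiliar with the two metrics, the sketch below implements the standard definitions of SNR and PSNR in NumPy. The functions snr_db and psnr_db are hypothetical helpers; they are assumed, not verified here, to follow the same conventions as scico.metric.snr and scico.metric.psnr, in particular the use of the maximum absolute reference value as the default peak.

```python
import numpy as np


def snr_db(reference, estimate):
    """Signal-to-noise ratio in dB: signal energy over error energy."""
    signal = np.sum(np.abs(reference) ** 2)
    noise = np.sum(np.abs(reference - estimate) ** 2)
    return 10.0 * np.log10(signal / noise)


def psnr_db(reference, estimate, signal_range=None):
    """Peak signal-to-noise ratio in dB: squared peak value over mean squared error."""
    if signal_range is None:
        signal_range = np.abs(reference).max()   # assumed default peak value
    mse = np.mean(np.abs(reference - estimate) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


reference = np.array([[0.0, 1.0], [1.0, 0.0]])
estimate = reference + 0.1                       # constant error of 0.1 per pixel
print(round(snr_db(reference, estimate), 2))     # 16.99
print(round(psnr_db(reference, estimate), 2))    # 20.0
```

With these definitions PSNR exceeds SNR whenever the mean signal power is below the squared peak value, which is consistent with the gap between the two numbers reported above.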
Plot comparison.
[11]:
np.random.seed(123)
indx = np.random.randint(0, high=maxn)
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(15, 5))
plot.imview(test_ds["label"][indx, ..., 0], title="Ground truth", cbar=None, fig=fig, ax=ax[0])
plot.imview(
    test_ds["image"][indx, ..., 0],
    title="Blurred: \nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], test_ds["image"][indx, ..., 0]),
    ),
    cbar=None,
    fig=fig,
    ax=ax[1],
)
plot.imview(
    output[indx, ..., 0],
    title="ODPNet Reconstruction\nSNR: %.2f (dB), PSNR: %.2f"
    % (
        metric.snr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
        metric.psnr(test_ds["label"][indx, ..., 0], output[indx, ..., 0]),
    ),
    fig=fig,
    ax=ax[2],
)
divider = make_axes_locatable(ax[2])
cax = divider.append_axes("right", size="5%", pad=0.2)
fig.colorbar(ax[2].get_images()[0], cax=cax, label="arbitrary units")
fig.show()
Plot convergence statistics. Statistics are available only if a training cycle was actually run (i.e. the final-epoch results were not read from a checkpoint).
[12]:
if stats_object is not None and len(stats_object.iterations) > 0:
    hist = stats_object.history(transpose=True)
    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(12, 5))
    plot.plot(
        np.vstack((hist.Train_Loss, hist.Eval_Loss)).T,
        x=hist.Epoch,
        ptyp="semilogy",
        title="Loss function",
        xlbl="Epoch",
        ylbl="Loss value",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[0],
    )
    plot.plot(
        np.vstack((hist.Train_SNR, hist.Eval_SNR)).T,
        x=hist.Epoch,
        title="Metric",
        xlbl="Epoch",
        ylbl="SNR (dB)",
        lgnd=("Train", "Test"),
        fig=fig,
        ax=ax[1],
    )
    fig.show()