Deconvolution Microscopy (All Channels)

This example partially replicates a GlobalBioIm example using the microscopy data provided by the EPFL Biomedical Imaging Group.

This example uses class admm.ADMM to solve an image deconvolution problem with isotropic total variation (TV) regularization

\[\mathrm{argmin}_{\mathbf{x}} \; (1/2) \| M (\mathbf{y} - A \mathbf{x}) \|_2^2 + \lambda \| C \mathbf{x} \|_{2,1} + \iota_{\mathrm{NN}}(\mathbf{x}) \;,\]

where \(M\) is a mask operator, \(A\) is circular convolution, \(\mathbf{y}\) is the blurred image, \(C\) is a convolutional gradient operator, \(\iota_{\mathrm{NN}}\) is the indicator function of the non-negativity constraint, and \(\mathbf{x}\) is the desired image.
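To make the terms of the objective concrete, the following is a minimal NumPy sketch that evaluates it for a small array. It is only an illustration: the helper names `circ_conv` and `objective` are hypothetical, the gradient is assumed to be a circular forward difference along each axis, and none of this is the SCICO implementation.

```python
import numpy as np


def circ_conv(x, h):
    """Circular convolution of x with kernel h (zero-padded to x.shape) via the FFT."""
    return np.real(np.fft.ifftn(np.fft.fftn(h, s=x.shape) * np.fft.fftn(x)))


def objective(x, y, h, mask, lam):
    """Evaluate (1/2)||M(y - A x)||_2^2 + lam ||C x||_{2,1} + indicator(x >= 0)."""
    if np.any(x < 0):
        return np.inf  # the non-negativity indicator is infinite outside the constraint set
    fidelity = 0.5 * np.sum((mask * (y - circ_conv(x, h))) ** 2)
    # Circular forward differences along each axis, stacked on a new leading axis.
    grad = np.stack([np.roll(x, -1, axis=a) - x for a in range(x.ndim)])
    # l2 norm across the gradient components, then l1 norm over all pixels.
    tv = np.sum(np.sqrt(np.sum(grad**2, axis=0)))
    return fidelity + lam * tv


# Hypothetical test data: a constant image is unchanged by a normalized blur
# and has zero gradient, so every term of the objective vanishes.
x_const = np.ones((8, 8))
h_box = np.full((3, 3), 1 / 9)
val_const = objective(x_const, x_const, h_box, np.ones_like(x_const), lam=2e-6)
val_neg = objective(-x_const, x_const, h_box, np.ones_like(x_const), lam=2e-6)
```

The \(\ell_{2,1}\) norm takes the Euclidean norm of the gradient vector at each pixel before summing, which is what makes the TV term isotropic.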

[1]:
import numpy as np

import ray
import scico.numpy as snp
from scico import functional, linop, loss, plot
from scico.examples import downsample_volume, epfl_deconv_data, tile_volume_slices
from scico.optimize.admm import ADMM, CircularConvolveSolver
plot.config_notebook_plotting()

Get and preprocess data. We downsample the data for the purposes of the example. Reducing the downsampling rate will make the example slower and more memory-intensive. To run this example on a GPU it may be necessary to set the environment variables XLA_PYTHON_CLIENT_ALLOCATOR=platform and XLA_PYTHON_CLIENT_PREALLOCATE=false. If your GPU does not have enough memory, you can try setting the environment variable JAX_PLATFORM_NAME=cpu to run on the CPU.
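One way to set these variables from within Python is sketched below. It assumes the assignments run before JAX (and therefore scico) is first imported, which is the safe ordering; setting them in the shell before launching Python is an equivalent alternative.

```python
import os

# Set memory-related options only if the user has not already chosen values.
os.environ.setdefault("XLA_PYTHON_CLIENT_ALLOCATOR", "platform")
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")

# Uncomment to force CPU execution if GPU memory is insufficient:
# os.environ["JAX_PLATFORM_NAME"] = "cpu"
```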

[2]:
downsampling_rate = 2

y_list = []
y_pad_list = []
psf_list = []
for channel in range(3):
    y, psf = epfl_deconv_data(channel, verbose=True)  # get data
    y = downsample_volume(y, downsampling_rate)  # downsample
    psf = downsample_volume(psf, downsampling_rate)
    y -= y.min()  # normalize y
    y /= y.max()
    psf /= psf.sum()  # normalize psf
    if channel == 0:
        padding = [[0, p] for p in snp.array(psf.shape) - 1]
        mask = snp.pad(snp.ones_like(y), padding)
    y_pad = snp.pad(y, padding)  # zero-padded version of y
    y_list.append(y)
    y_pad_list.append(y_pad)
    psf_list.append(psf)
y = snp.stack(y_list, axis=-1)
yshape = y.shape
del y_list
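The normalization and padding steps in the loop above can be checked in isolation. The sketch below repeats them in plain NumPy on random arrays whose shapes are hypothetical; the helper name `normalize_and_pad` is an illustration and is not part of scico.

```python
import numpy as np


def normalize_and_pad(y, psf):
    """Scale y to [0, 1], scale psf to unit sum, and zero-pad y by psf.shape - 1."""
    y = y - y.min()
    y = y / y.max()
    psf = psf / psf.sum()
    # Pad only at the end of each axis, by one less than the PSF size on that axis.
    padding = [(0, p - 1) for p in psf.shape]
    y_pad = np.pad(y, padding)
    # The mask is 1 on the original samples and 0 on the padded region.
    mask = np.pad(np.ones_like(y), padding)
    return y_pad, psf, mask


rng = np.random.default_rng(0)
y_demo = rng.uniform(1.0, 5.0, size=(6, 8, 8))  # hypothetical volume
psf_demo = rng.uniform(0.0, 1.0, size=(3, 5, 5))  # hypothetical PSF
y_pad_demo, psf_norm, mask_demo = normalize_and_pad(y_demo, psf_demo)
print(y_pad_demo.shape, int(mask_demo.sum()), y_demo.size)
```

Because the mask is zero on the padded region, the data fidelity term only compares the blurred estimate with the measurements where measurements exist.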

Define problem and algorithm parameters.

[3]:
λ = 2e-6  # ℓ2,1 norm (TV) regularization parameter
ρ0 = 1e-3  # ADMM penalty parameter for first auxiliary variable
ρ1 = 1e-3  # ADMM penalty parameter for second auxiliary variable
ρ2 = 1e-3  # ADMM penalty parameter for third auxiliary variable
maxiter = 100  # number of ADMM iterations

Initialize ray, determine available computing resources, and put large arrays in object store.

[4]:
ray.init()

ngpu = 0
ar = ray.available_resources()
ncpu = max(int(ar["CPU"]) // 3, 1)
if "GPU" in ar:
    ngpu = int(ar["GPU"]) // 3
print(f"Running on {ncpu} CPUs and {ngpu} GPUs per process")

y_pad_list = ray.put(y_pad_list)
psf_list = ray.put(psf_list)
mask_store = ray.put(mask)
Running on 26 CPUs and 2 GPUs per process
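The resource split used above is plain integer division of the available resources across the three channel solves. A standalone sketch of that arithmetic follows; the function name and the example dictionary are hypothetical, with values chosen only to be consistent with the printed output.

```python
def per_task_resources(available, ntasks=3):
    """Divide available CPUs and GPUs evenly among ntasks parallel tasks."""
    # Always reserve at least one CPU per task.
    ncpu = max(int(available.get("CPU", 0)) // ntasks, 1)
    # GPUs are optional; with fewer GPUs than tasks this is 0.
    ngpu = int(available.get("GPU", 0)) // ntasks
    return ncpu, ngpu


print(per_task_resources({"CPU": 80.0, "GPU": 6.0}))  # (26, 2)
```

When there are fewer GPUs than tasks, `ngpu` is 0 and no GPU is reserved for any task.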

Define ray remote function for parallel solves.

[5]:
@ray.remote(num_cpus=ncpu, num_gpus=ngpu)
def deconvolve_channel(channel):
    """Deconvolve a single channel."""
    y_pad = ray.get(y_pad_list)[channel]
    psf = ray.get(psf_list)[channel]
    mask = ray.get(mask_store)
    M = linop.Diagonal(mask)
    C0 = linop.CircularConvolve(
        h=psf, input_shape=mask.shape, h_center=snp.array(psf.shape) / 2 - 0.5  # forward operator
    )
    C1 = linop.FiniteDifference(input_shape=mask.shape, circular=True)  # gradient operator
    C2 = linop.Identity(mask.shape)  # identity operator
    g0 = loss.SquaredL2Loss(y=y_pad, A=M)  # loss function (forward model)
    g1 = λ * functional.L21Norm()  # TV penalty (when applied to gradient)
    g2 = functional.NonNegativeIndicator()  # non-negativity constraint
    if channel == 0:
        print("Displaying solver status for channel 0")
        display = True
    else:
        display = False
    solver = ADMM(
        f=None,
        g_list=[g0, g1, g2],
        C_list=[C0, C1, C2],
        rho_list=[ρ0, ρ1, ρ2],
        maxiter=maxiter,
        itstat_options={"display": display, "period": 10, "overwrite": False},
        x0=y_pad,
        subproblem_solver=CircularConvolveSolver(),
    )
    x_pad = solver.solve()
    x = x_pad[: yshape[0], : yshape[1], : yshape[2]]
    return (x, solver.itstat_object.history(transpose=True))
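With `f=None` and operators that are all diagonalized by the DFT (circular convolution, circular finite differences, and the identity), the ADMM update for \(\mathbf{x}\) reduces to a linear system that can be solved with FFTs. The following 1D NumPy sketch illustrates that idea and checks it against a dense solve. It is an illustration under stated assumptions (the kernels, signal length, and right-hand sides are hypothetical) and is not the code of `CircularConvolveSolver`.

```python
import numpy as np


def circ_matrix(kernel, n):
    """Dense matrix of circular convolution with `kernel` on length-n signals."""
    k = np.zeros(n)
    k[: len(kernel)] = kernel
    return np.stack([np.roll(k, j) for j in range(n)], axis=1)


n = 16
rng = np.random.default_rng(0)
kernels = [
    np.array([0.25, 0.5, 0.25]),  # blur (plays the role of C0)
    np.array([-1.0, 1.0]),  # circular difference (plays the role of C1)
    np.array([1.0]),  # identity (plays the role of C2)
]
rhos = [1e-3, 1e-3, 1e-3]
v = [rng.standard_normal(n) for _ in kernels]  # stands in for z_i - u_i

# Frequency-domain solve of (sum_i rho_i C_i^T C_i) x = sum_i rho_i C_i^T v_i.
K = [np.fft.fft(k, n) for k in kernels]
num = sum(r * np.conj(Kf) * np.fft.fft(vi) for r, Kf, vi in zip(rhos, K, v))
den = sum(r * np.abs(Kf) ** 2 for r, Kf in zip(rhos, K))
x_fft = np.real(np.fft.ifft(num / den))

# Dense reference solution of the same normal equations.
mats = [circ_matrix(k, n) for k in kernels]
lhs = sum(r * Cm.T @ Cm for r, Cm in zip(rhos, mats))
rhs = sum(r * Cm.T @ vi for r, Cm, vi in zip(rhos, mats, v))
x_dense = np.linalg.solve(lhs, rhs)

print(np.allclose(x_fft, x_dense))
```

The identity term keeps the denominator strictly positive, so the frequency-domain division is always well defined.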

Solve problems for all three channels in parallel and extract results.

[6]:
ray_return = ray.get([deconvolve_channel.remote(channel) for channel in range(3)])
x = snp.stack([t[0] for t in ray_return], axis=-1)
solve_stats = [t[1] for t in ray_return]
(deconvolve_channel pid=44343) Displaying solver status for channel 0
(deconvolve_channel pid=44343) Iter  Time      Objective  Prml Rsdl  Dual Rsdl
(deconvolve_channel pid=44343) -----------------------------------------------
(deconvolve_channel pid=44343)    0  1.31e+01  2.830e-01  2.780e+02  1.949e+02
(deconvolve_channel pid=44343)   10  2.35e+01  1.033e+00  5.524e+01  9.689e+01
(deconvolve_channel pid=44343)   20  3.23e+01  1.158e+00  7.474e+01  4.940e+01
(deconvolve_channel pid=44343)   30  4.11e+01  1.297e+00  4.547e+01  5.652e+01
(deconvolve_channel pid=44343)   40  4.90e+01  1.528e+00  4.879e+01  3.606e+01
(deconvolve_channel pid=44343)   50  5.67e+01  1.733e+00  3.740e+01  3.763e+01
(deconvolve_channel pid=44343)   60  6.37e+01  1.919e+00  3.390e+01  3.228e+01
(deconvolve_channel pid=44343)   70  7.06e+01  2.133e+00  2.996e+01  2.785e+01
(deconvolve_channel pid=44343)   80  7.75e+01  2.341e+00  2.456e+01  2.723e+01
(deconvolve_channel pid=44343)   90  8.44e+01  2.560e+00  2.437e+01  2.325e+01
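`ray.get` applied to a list of object references returns the results in the same order as the list, which is why the tuples can be unpacked by channel index. The same launch, gather, and unpack pattern is sketched below using the standard library `concurrent.futures` module in place of ray; `fake_solve` is a hypothetical stand-in for `deconvolve_channel`.

```python
from concurrent.futures import ThreadPoolExecutor


def fake_solve(channel):
    """Hypothetical stand-in for a per-channel solve: returns (result, stats)."""
    return (f"x{channel}", {"channel": channel, "iterations": 100})


# Launch one task per channel; map() yields results in submission order.
with ThreadPoolExecutor(max_workers=3) as pool:
    returns = list(pool.map(fake_solve, range(3)))

# Split the list of (result, stats) tuples into two parallel lists.
results = [t[0] for t in returns]
stats = [t[1] for t in returns]
print(results)
```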

Show the recovered image.

[7]:
fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(14, 7))
plot.imview(tile_volume_slices(y), title="Blurred measurements", fig=fig, ax=ax[0])
plot.imview(tile_volume_slices(x), title="Deconvolved image", fig=fig, ax=ax[1])
fig.show()
[Figure: tiled volume slices of the blurred measurements (left) and the deconvolved image (right).]

Plot convergence statistics.

[8]:
fig, ax = plot.subplots(nrows=1, ncols=3, figsize=(18, 5))
plot.plot(
    np.stack([s.Objective for s in solve_stats]).T,
    title="Objective function",
    xlbl="Iteration",
    ylbl="Functional value",
    lgnd=("CY3", "DAPI", "FITC"),
    fig=fig,
    ax=ax[0],
)
plot.plot(
    np.stack([s.Prml_Rsdl for s in solve_stats]).T,
    ptyp="semilogy",
    title="Primal Residual",
    xlbl="Iteration",
    lgnd=("CY3", "DAPI", "FITC"),
    fig=fig,
    ax=ax[1],
)
plot.plot(
    np.stack([s.Dual_Rsdl for s in solve_stats]).T,
    ptyp="semilogy",
    title="Dual Residual",
    xlbl="Iteration",
    lgnd=("CY3", "DAPI", "FITC"),
    fig=fig,
    ax=ax[2],
)
fig.show()
[Figure: objective function, primal residual, and dual residual against iteration number for the CY3, DAPI, and FITC channels.]
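The `np.stack(...).T` expressions in the plotting cell arrange the per-channel histories as an array with one row per iteration and one column per channel, so that each column can be drawn as a separate curve. A small sketch with made-up values follows; the `Stats` namedtuple is a hypothetical stand-in for the object returned by `history(transpose=True)`.

```python
from collections import namedtuple

import numpy as np

# Hypothetical subset of the iteration statistics fields.
Stats = namedtuple("Stats", ["Iter", "Objective"])

# One record per channel, each holding arrays indexed by iteration (made-up values).
demo_stats = [
    Stats(Iter=np.arange(4), Objective=np.array([1.0, 0.8, 0.7, 0.65]) * (c + 1))
    for c in range(3)
]

# Stack to (channels, iterations), then transpose to (iterations, channels).
obj = np.stack([s.Objective for s in demo_stats]).T
print(obj.shape)  # (4, 3)
```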