PPP (with BM3D) Image Demosaicing

This example demonstrates the use of the ADMM Plug and Play Priors (PPP) algorithm [50], with the BM3D [16] denoiser, for solving a raw image demosaicing problem.

[1]:
import numpy as np

from bm3d import bm3d_rgb
from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007

import scico
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric, plot
from scico.data import kodim23
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info
plot.config_notebook_plotting()

Read a ground truth image and crop a 256 × 256 region from it.

[2]:
img = snp.array(kodim23(asfloat=True)[160:416, 60:316])

Define the demosaicing forward operator and its transpose.

[3]:
def Afn(x):
    """Map an RGB image to a single channel image with each pixel
    representing a single colour according to the colour filter array.
    """

    y = snp.zeros(x.shape[0:2])
    y = y.at[1::2, 1::2].set(x[1::2, 1::2, 0])
    y = y.at[0::2, 1::2].set(x[0::2, 1::2, 1])
    y = y.at[1::2, 0::2].set(x[1::2, 0::2, 1])
    y = y.at[0::2, 0::2].set(x[0::2, 0::2, 2])
    return y


def ATfn(x):
    """Back project a single channel raw image to an RGB image with zeros
    at the locations of undefined samples.
    """

    y = snp.zeros(x.shape + (3,))
    y = y.at[1::2, 1::2, 0].set(x[1::2, 1::2])
    y = y.at[0::2, 1::2, 1].set(x[0::2, 1::2])
    y = y.at[1::2, 0::2, 1].set(x[1::2, 0::2])
    y = y.at[0::2, 0::2, 2].set(x[0::2, 0::2])
    return y
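Because `Afn` and `ATfn` are supplied to `linop.LinearOperator` as a forward/adjoint pair, it is worth checking the adjoint identity ⟨Ax, y⟩ = ⟨x, Aᵀy⟩. The sketch below is our own NumPy re-implementation (it uses in-place assignment in place of JAX's `.at[...].set`) that checks this identity on random data. It also checks that A Aᵀ = I, which holds because each raw pixel corresponds to exactly one RGB sample:

```python
import numpy as np

def afn(x):
    """NumPy version of Afn: RGB image (H, W, 3) -> BGGR raw image (H, W)."""
    y = np.zeros(x.shape[:2])
    y[1::2, 1::2] = x[1::2, 1::2, 0]  # red
    y[0::2, 1::2] = x[0::2, 1::2, 1]  # green
    y[1::2, 0::2] = x[1::2, 0::2, 1]  # green
    y[0::2, 0::2] = x[0::2, 0::2, 2]  # blue
    return y

def atfn(y):
    """NumPy version of ATfn: raw image (H, W) -> RGB image, zeros where unsampled."""
    x = np.zeros(y.shape + (3,))
    x[1::2, 1::2, 0] = y[1::2, 1::2]
    x[0::2, 1::2, 1] = y[0::2, 1::2]
    x[1::2, 0::2, 1] = y[1::2, 0::2]
    x[0::2, 0::2, 2] = y[0::2, 0::2]
    return x

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8, 3))
y = rng.standard_normal((8, 8))

# Adjoint identity <A x, y> = <x, A^T y> (np.vdot flattens its arguments).
lhs = np.vdot(afn(x), y)
rhs = np.vdot(x, atfn(y))
print(np.isclose(lhs, rhs))  # True
# Each raw pixel maps to exactly one RGB sample, so A A^T = I.
print(np.allclose(afn(atfn(y)), y))  # True
```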

Define a baseline demosaicing function based on the demosaicing algorithm of [37] from the colour_demosaicing package.

[4]:
def demosaic(cfaimg):
    """Apply baseline demosaicing."""
    return demosaicing_CFA_Bayer_Menon2007(cfaimg, pattern="BGGR").astype(np.float32)
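The indexing in `Afn` determines the colour filter array layout, and the baseline demosaicer must be told the same layout. A small sketch (the helper `cfa_tile` is hypothetical, written only for illustration) reads off the 2 × 2 tile that `Afn` implies, assuming channel 0 is red as in an RGB image, and shows why `pattern="BGGR"` is passed above:

```python
import numpy as np

def cfa_tile():
    """Return the 2x2 CFA tile implied by the indexing used in Afn."""
    names = np.array(["R", "G", "B"])  # channel index -> colour name
    chan = np.empty((2, 2), dtype=int)
    chan[1, 1] = 0  # Afn: y[1::2, 1::2] <- x[..., 0] (red)
    chan[0, 1] = 1  # Afn: y[0::2, 1::2] <- x[..., 1] (green)
    chan[1, 0] = 1  # Afn: y[1::2, 0::2] <- x[..., 1] (green)
    chan[0, 0] = 2  # Afn: y[0::2, 0::2] <- x[..., 2] (blue)
    return names[chan]

tile = cfa_tile()
# Read the tile row-major: top-left, top-right, bottom-left, bottom-right.
pattern = "".join(tile.ravel())
print(pattern)  # BGGR
```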

Create a test image by color filter array sampling of the ground truth image, followed by the addition of white Gaussian noise.

[5]:
s = Afn(img)
rgbshp = s.shape + (3,)  # shape of reconstructed RGB image
σ = 2e-2  # noise standard deviation
noise, key = scico.random.randn(s.shape, seed=0)
sn = s + σ * noise

Compute a baseline solution: demosaic the noisy raw image with the baseline function, then denoise the result with BM3D at noise level 3σ.

[6]:
imgb = snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32))

Set up an ADMM solver object, using the baseline solution as the initializer. BM3D [16] is used as the denoiser, via the code released with [35].
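For reference, the iterations can be sketched in the standard scaled form of ADMM with the constraint $\mathbf{x} = \mathbf{z}$ (since $C$ is the identity), writing $\mathbf{y}$ for the noisy raw image and $\mathcal{D}$ for the denoiser that takes the place of the proximal operator of $g/\rho$. The notation $\mathcal{D}$ and the exact scaling convention are our assumptions; scico's internal conventions may differ in detail.

```latex
\begin{aligned}
\mathbf{x}^{k+1} &= \operatorname*{arg\,min}_{\mathbf{x}} \; \tfrac{1}{2}\|\mathbf{y} - A\mathbf{x}\|_2^2
  + \tfrac{\rho}{2}\|\mathbf{x} - \mathbf{z}^{k} + \mathbf{u}^{k}\|_2^2
  = (A^T A + \rho I)^{-1}\big(A^T \mathbf{y} + \rho(\mathbf{z}^{k} - \mathbf{u}^{k})\big) \\
\mathbf{z}^{k+1} &= \mathcal{D}\big(\mathbf{x}^{k+1} + \mathbf{u}^{k}\big) \\
\mathbf{u}^{k+1} &= \mathbf{u}^{k} + \mathbf{x}^{k+1} - \mathbf{z}^{k+1}
\end{aligned}
```

The $\mathbf{x}$-update is the linear system that `LinearSubproblemSolver` solves by conjugate gradient.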

[7]:
A = linop.LinearOperator(input_shape=rgbshp, output_shape=s.shape, eval_fn=Afn, adj_fn=ATfn)  # CFA sampling operator
f = loss.SquaredL2Loss(y=sn, A=A)  # data fidelity term
C = linop.Identity(input_shape=rgbshp)  # the prior g is applied directly to x
g = 1.8e-1 * 6.1e-2 * functional.BM3D(is_rgb=True)  # scaled BM3D prior (first factor equals ρ)
ρ = 1.8e-1  # ADMM penalty parameter
maxiter = 12  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=imgb,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)
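To make the structure of the solver concrete, here is a minimal NumPy sketch of plug-and-play ADMM. It is not scico's implementation, and it rests on two assumptions: the CFA operator enters only through the diagonal of AᵀA (so the x-update is an elementwise division rather than a CG solve), and a 3 × 3 mean filter (`box_denoise`, hypothetical) stands in for BM3D.

```python
import numpy as np

def box_denoise(v):
    """Hypothetical stand-in for BM3D: 3x3 mean filter applied per channel."""
    h, w = v.shape[:2]
    pad = ((1, 1), (1, 1)) + ((0, 0),) * (v.ndim - 2)  # pad spatial axes only
    p = np.pad(v, pad, mode="edge")
    return sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0

def pnp_admm(aty, mask, denoise, rho, x0, maxiter=12):
    """Plug-and-play ADMM for 1/2 ||y - A x||^2 + g(x) with constraint x = z.

    aty:  A^T y on the full RGB grid (zeros at unsampled locations)
    mask: diagonal of A^T A (1 where a sample exists, 0 elsewhere)
    """
    x = x0.copy()
    z = x0.copy()
    u = np.zeros_like(x0)
    for _ in range(maxiter):
        # x-update: solve (A^T A + rho I) x = A^T y + rho (z - u); A^T A is diagonal here
        x = (aty + rho * (z - u)) / (mask + rho)
        # z-update: the denoiser takes the place of the proximal operator of g
        z = denoise(x + u)
        # scaled dual variable update
        u = u + x - z
    return x

# Usage on a tiny constant image with the BGGR sampling mask of this example.
mask = np.zeros((8, 8, 3))
mask[1::2, 1::2, 0] = 1.0
mask[0::2, 1::2, 1] = 1.0
mask[1::2, 0::2, 1] = 1.0
mask[0::2, 0::2, 2] = 1.0
truth = np.full((8, 8, 3), 0.5)
aty = mask * truth  # noise-free A^T y for this mask
x = pnp_admm(aty, mask, box_denoise, rho=1.8e-1, x0=truth)
print(np.allclose(x, truth))  # True: a constant image is a fixed point
```

Starting from the true constant image, the iteration stays there because the mean filter leaves a constant image unchanged. In the notebook, the denoiser is BM3D and the x-update is solved by CG.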

Run the solver.

[8]:
print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  7.66e+00  5.788e+00  2.298e+00      1  1.971e-09
   1  1.39e+01  4.773e+00  8.708e-01      2  2.083e-09
   2  2.02e+01  3.597e+00  1.122e+00      2  9.364e-10
   3  2.59e+01  2.751e+00  1.480e+00      2  1.387e-09
   4  3.24e+01  2.231e+00  1.549e+00      2  1.406e-09
   5  3.86e+01  1.961e+00  1.283e+00      2  5.103e-10
   6  4.42e+01  1.758e+00  9.265e-01      2  6.186e-10
   7  5.01e+01  1.400e+00  4.759e-01      1  8.766e-04
   8  5.63e+01  1.216e+00  7.355e-01      2  3.677e-10
   9  6.24e+01  9.984e-01  7.076e-01      2  2.950e-10
  10  6.80e+01  8.883e-01  6.705e-01      2  3.903e-10
  11  7.34e+01  6.458e-01  4.391e-01      1  9.664e-04
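The log shows CG using only 1 or 2 iterations per ADMM step. A plausible explanation, based on our own reasoning rather than on the source: for this CFA operator, AᵀA is diagonal with 0/1 entries, so AᵀA + ρI has only two distinct eigenvalues (ρ and 1 + ρ), and CG terminates in at most two iterations in exact arithmetic. The sketch below uses a plain CG of our own (not scico's solver, whose tolerance semantics may differ) on the same diagonal operator:

```python
import numpy as np

def cg(apply_m, b, tol=1e-10, maxiter=100):
    """Plain conjugate gradient for a symmetric positive definite operator."""
    x = np.zeros_like(b)
    r = b - apply_m(x)
    p = r.copy()
    rs = np.vdot(r, r)
    bnorm = np.linalg.norm(b)
    for it in range(1, maxiter + 1):
        mp = apply_m(p)
        alpha = rs / np.vdot(p, mp)
        x = x + alpha * p
        r = r - alpha * mp
        rs_new = np.vdot(r, r)
        if np.sqrt(rs_new) <= tol * bnorm:  # relative residual test
            return x, it
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x, maxiter

# Diagonal of A^T A for the BGGR CFA used in this example (1 where sampled).
mask = np.zeros((8, 8, 3))
mask[1::2, 1::2, 0] = 1.0
mask[0::2, 1::2, 1] = 1.0
mask[1::2, 0::2, 1] = 1.0
mask[0::2, 0::2, 2] = 1.0

rho = 1.8e-1
b = np.random.default_rng(0).standard_normal(mask.shape)
x, iters = cg(lambda v: (mask + rho) * v, b)
print(iters)  # expected to be at most 2
```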

Show the reference image and the baseline and PPP demosaiced images, with PSNR values in the titles.

[9]:
fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(img, title="Reference", fig=fig, ax=ax[0])
plot.imview(imgb, title="Baseline demosaic: %.2f (dB)" % metric.psnr(img, imgb), fig=fig, ax=ax[1])
plot.imview(x, title="PPP demosaic: %.2f (dB)" % metric.psnr(img, x), fig=fig, ax=ax[2])
fig.show()
[Figure: reference image, baseline demosaic, and PPP demosaic, each titled with its PSNR in dB]
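The PSNR values in the panel titles compare each reconstruction with the reference. A minimal sketch of the usual definition, PSNR = 10 log₁₀(R² / MSE) with signal range R, assuming R = 1 for images scaled to [0, 1] (scico's `metric.psnr` may determine the range differently):

```python
import numpy as np

def psnr(ref, img, signal_range=1.0):
    """Peak signal-to-noise ratio in dB, assuming a fixed signal range."""
    ref = np.asarray(ref, dtype=float)
    img = np.asarray(img, dtype=float)
    mse = np.mean((ref - img) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)

ref = np.full((4, 4, 3), 0.5)
img = ref + 0.01  # uniform error of 0.01, so MSE = 1e-4
print(f"{psnr(ref, img):.2f} dB")  # 40.00 dB
```

Each tenfold reduction in RMS error adds 20 dB.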

Plot convergence statistics.

[10]:
plot.plot(
    snp.vstack((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ptyp="semilogy",
    title="Residuals",
    xlbl="Iteration",
    lgnd=("Primal", "Dual"),
)
[Figure: "Residuals", primal and dual residuals versus iteration on a semilog y axis]
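For reference, the standard ADMM primal and dual residuals for the constraint $\mathbf{x} = \mathbf{z}$ are given below. scico's reported values may use a different scaling or normalization, so treat these as the textbook definitions rather than the library's exact formulas.

```latex
r^{k} = \|\mathbf{x}^{k} - \mathbf{z}^{k}\|_2, \qquad
s^{k} = \rho\,\|\mathbf{z}^{k} - \mathbf{z}^{k-1}\|_2
```

Both decreasing indicates that the iterates are approaching agreement between $\mathbf{x}$ and $\mathbf{z}$. Because a plug-and-play denoiser is not in general a true proximal operator, convergence is not guaranteed, so this plot is a practical check rather than a certificate.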