PPP (with BM3D) Image Demosaicing

This example demonstrates the use of the ADMM Plug and Play Priors (PPP) algorithm [56], with the BM3D [17] denoiser, for solving a raw image demosaicing problem.

[1]:
import numpy as np

from bm3d import bm3d_rgb

# Workarounds for colour_demosaicing incompatibility with NumPy 2.x
np.float_ = np.float64
np.float = np.float64
np.complex = np.complex128
np.sctypes = {
    "float": [np.float16, np.float32, np.float64, np.longdouble],
    "int": [np.int8, np.int16, np.int32, np.int64],
}
import komplot as kplt
from colour_demosaicing import demosaicing_CFA_Bayer_Menon2007

import scico
import scico.numpy as snp
import scico.random
from scico import functional, linop, loss, metric
from scico.data import kodim23
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.util import device_info

Read a ground truth image.

[2]:
img = snp.array(kodim23(asfloat=True)[160:416, 60:316])

Define the demosaicing forward operator and its transpose.

[3]:
def Afn(x):
    """Map an RGB image to a single channel image with each pixel
    representing a single colour according to the colour filter array.
    """

    y = snp.zeros(x.shape[0:2])
    y = y.at[1::2, 1::2].set(x[1::2, 1::2, 0])
    y = y.at[0::2, 1::2].set(x[0::2, 1::2, 1])
    y = y.at[1::2, 0::2].set(x[1::2, 0::2, 1])
    y = y.at[0::2, 0::2].set(x[0::2, 0::2, 2])
    return y


def ATfn(x):
    """Back project a single channel raw image to an RGB image with zeros
    at the locations of undefined samples.
    """

    y = snp.zeros(x.shape + (3,))
    y = y.at[1::2, 1::2, 0].set(x[1::2, 1::2])
    y = y.at[0::2, 1::2, 1].set(x[0::2, 1::2])
    y = y.at[1::2, 0::2, 1].set(x[1::2, 0::2])
    y = y.at[0::2, 0::2, 2].set(x[0::2, 0::2])
    return y
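A quick way to check that the two functions above really form an operator/transpose pair is a numerical adjoint test: for random `u` and `v`, the inner products ⟨A u, v⟩ and ⟨u, Aᵀ v⟩ should agree. The sketch below is illustrative only. It re-implements the same sampling pattern in plain NumPy (the names `Afn_np` and `ATfn_np` are hypothetical, not part of the example) so that it runs without SCICO or JAX.

```python
import numpy as np


def Afn_np(x):
    """NumPy copy of the forward operator: RGB image -> single-channel raw image."""
    y = np.zeros(x.shape[0:2])
    y[1::2, 1::2] = x[1::2, 1::2, 0]  # red samples
    y[0::2, 1::2] = x[0::2, 1::2, 1]  # green samples
    y[1::2, 0::2] = x[1::2, 0::2, 1]  # green samples
    y[0::2, 0::2] = x[0::2, 0::2, 2]  # blue samples
    return y


def ATfn_np(x):
    """NumPy copy of the transpose: raw image -> RGB image, zeros elsewhere."""
    y = np.zeros(x.shape + (3,))
    y[1::2, 1::2, 0] = x[1::2, 1::2]
    y[0::2, 1::2, 1] = x[0::2, 1::2]
    y[1::2, 0::2, 1] = x[1::2, 0::2]
    y[0::2, 0::2, 2] = x[0::2, 0::2]
    return y


rng = np.random.default_rng(0)
u = rng.standard_normal((8, 8, 3))  # random RGB image
v = rng.standard_normal((8, 8))  # random raw image

# Adjoint test: <A u, v> should equal <u, A^T v>
lhs = np.sum(Afn_np(u) * v)
rhs = np.sum(u * ATfn_np(v))
adjoint_ok = bool(np.isclose(lhs, rhs))

# A keeps exactly one sample per pixel, so A A^T acts as the identity on raw images
identity_ok = bool(np.allclose(Afn_np(ATfn_np(v)), v))

print(adjoint_ok, identity_ok)
```

The second check reflects the structure of the sampling pattern: every pixel location contributes exactly one colour sample, so back projecting and resampling returns the original raw image.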

Define a baseline demosaicing function based on the demosaicing algorithm of [42] from package colour_demosaicing.

[4]:
def demosaic(cfaimg):
    """Apply baseline demosaicing."""
    return demosaicing_CFA_Bayer_Menon2007(cfaimg, pattern="BGGR").astype(np.float32)

Create a test image by sampling the ground truth with the colour filter array and adding white Gaussian noise.

[5]:
s = Afn(img)
rgbshp = s.shape + (3,)  # shape of reconstructed RGB image
σ = 2e-2  # noise standard deviation
noise, key = scico.random.randn(s.shape, seed=0)
sn = s + σ * noise

Compute a baseline solution by demosaicing the noisy raw image and then denoising the result with BM3D.

[6]:
imgb = snp.clip(snp.array(bm3d_rgb(demosaic(sn), 3 * σ).astype(np.float32)), 0.0, 1.0)

Set up an ADMM solver object. Note the use of the baseline solution as an initializer. We use BM3D [17] as the denoiser, using the code released with [40].

[7]:
A = linop.LinearOperator(input_shape=rgbshp, output_shape=s.shape, eval_fn=Afn, adj_fn=ATfn)
f = loss.SquaredL2Loss(y=sn, A=A)
C = linop.Identity(input_shape=rgbshp)
g = 1.8e-1 * 6.1e-2 * functional.BM3D(is_rgb=True)
ρ = 1.8e-1  # ADMM penalty parameter
maxiter = 12  # number of ADMM iterations

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[C],
    rho_list=[ρ],
    x0=imgb,
    maxiter=maxiter,
    subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 100}),
    itstat_options={"display": True},
)
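To make the structure of the iteration concrete, the following is a minimal NumPy sketch of a plug-and-play ADMM loop for this problem. It is not SCICO's implementation, and it makes several assumptions that should be read as such: the denoiser is a 3x3 box filter standing in for BM3D, the helper names (`cfa_mask`, `box_denoise`, `ppp_admm_sketch`) are hypothetical, and the x-update uses a closed form rather than the conjugate gradient solver configured above. The closed form is available because AᵀA is diagonal for this sampling operator, so minimizing (1/2)‖A x − y‖² + (ρ/2)‖x − z + u‖² reduces to an elementwise division.

```python
import numpy as np


def cfa_mask(shape):
    """Diagonal of A^T A: 1 where a colour sample is observed, 0 elsewhere."""
    m = np.zeros(shape + (3,))
    m[1::2, 1::2, 0] = 1.0  # red
    m[0::2, 1::2, 1] = 1.0  # green
    m[1::2, 0::2, 1] = 1.0  # green
    m[0::2, 0::2, 2] = 1.0  # blue
    return m


def box_denoise(v):
    """3x3 box filter per channel: a crude stand-in for BM3D (assumption)."""
    h, w = v.shape[:2]
    p = np.pad(v, ((1, 1), (1, 1), (0, 0)), mode="edge")
    out = np.zeros_like(v)
    for di in range(3):
        for dj in range(3):
            out += p[di : di + h, dj : dj + w, :]
    return out / 9.0


def ppp_admm_sketch(raw, x0, rho=0.18, maxiter=12, denoise=box_denoise):
    """Plug-and-play ADMM with the constraint x = z and scaled dual variable u."""
    mask = cfa_mask(raw.shape)
    aty = mask * raw[..., None]  # A^T y: raw samples placed in their colour channel
    x, z, u = x0.copy(), x0.copy(), np.zeros_like(x0)
    for _ in range(maxiter):
        # x-update: solve (A^T A + rho I) x = A^T y + rho (z - u) elementwise
        x = (aty + rho * (z - u)) / (mask + rho)
        # z-update: the denoiser replaces the proximal operator of the prior
        z = denoise(x + u)
        # scaled dual update
        u = u + x - z
    return x


# Usage: a constant image is a fixed point of the iteration
raw = np.full((8, 8), 0.5)
x0 = np.full((8, 8, 3), 0.5)
x_hat = ppp_admm_sketch(raw, x0)
```

With a real denoiser such as BM3D, the denoising strength would also need to be tied to the penalty parameter; the box filter here has no such parameter and is used only to keep the sketch self-contained.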

Run the solver.

[8]:
print(f"Solving on {device_info()}\n")
x = snp.clip(solver.solve(), 0.0, 1.0)
hist = solver.itstat_object.history(transpose=True)
Solving on GPU (NVIDIA GeForce RTX 2080 Ti)

Iter  Time      Prml Rsdl  Dual Rsdl  CG It  CG Res
------------------------------------------------------
   0  8.54e+00  2.458e+00  4.229e-01      1  1.555e-09
   1  1.42e+01  2.028e+00  1.548e-01      2  1.785e-09
   2  1.90e+01  1.522e+00  2.045e-01      2  1.006e-09
   3  2.47e+01  1.160e+00  2.717e-01      2  6.849e-10
   4  2.98e+01  9.417e-01  2.798e-01      2  7.247e-10
   5  3.57e+01  8.315e-01  2.339e-01      2  6.042e-10
   6  4.06e+01  7.463e-01  1.658e-01      2  5.458e-10
   7  4.41e+01  5.972e-01  8.401e-02      1  8.530e-04
   8  4.74e+01  5.177e-01  1.312e-01      2  2.062e-10
   9  5.07e+01  4.256e-01  1.308e-01      2  2.315e-10
  10  5.40e+01  3.793e-01  1.263e-01      2  2.016e-10
  11  5.74e+01  2.749e-01  8.099e-02      1  9.929e-04

Show reference and demosaiced images.

[9]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
kplt.imview(img, title="Reference", ax=ax[0])
kplt.imview(imgb, title="Baseline demosaic: %.2f (dB)" % metric.psnr(img, imgb), ax=ax[1])
kplt.imview(x, title="PPP demosaic: %.2f (dB)" % metric.psnr(img, x), ax=ax[2])
fig.show()
../_images/examples_demosaic_ppp_bm3d_admm_17_0.png
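The figures in the panel titles are peak signal-to-noise ratios in dB. As a rough illustration of what such a value measures, the sketch below computes PSNR under the common definition 10·log10(R² / MSE) with an explicit signal range R. The function name `psnr_sketch` is hypothetical, and `metric.psnr` may choose its default signal range differently, so the sketch is not guaranteed to reproduce the exact numbers shown in the titles.

```python
import numpy as np


def psnr_sketch(reference, comparison, signal_range=1.0):
    """PSNR in dB: 10 log10(signal_range^2 / mean squared error)."""
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


# A uniform error of 0.1 on a [0, 1] image gives MSE = 0.01
ref = np.zeros((4, 4))
test_img = ref + 0.1
value = psnr_sketch(ref, test_img)
print(f"{value:.2f} dB")
```

Because the scale is logarithmic, halving the root mean squared error raises the PSNR by about 6 dB.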

Plot convergence statistics.

[10]:
kplt.plot(
    snp.array((hist.Prml_Rsdl, hist.Dual_Rsdl)).T,
    ylog=True,
    title="Residuals",
    xlabel="Iteration",
    legend=("Primal", "Dual"),
)
../_images/examples_demosaic_ppp_bm3d_admm_19_1.png