Comparison of DnCNN Variants for Image Denoising

This example demonstrates image denoising with DnCNN [64] networks trained for different noise levels, and with custom variants that have fewer network layers or that take the noise level as an input.
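To make "number of layers" and "noise level input" concrete, here is a toy, untrained NumPy sketch of a DnCNN-style network. It assumes the usual DnCNN design: a stack of 3×3 convolutions with ReLU activations, ending in a convolution that estimates the noise, which is then subtracted from the input (residual learning). Batch normalization is omitted, the weights are random, and the channel count is arbitrary. Feeding σ as an extra constant input channel is an assumed conditioning scheme for the "N" variants. The helpers `conv3x3`, `init_weights`, and `dncnn_forward` are hypothetical and are not scico's implementation.

```python
import numpy as np


def conv3x3(x, w):
    """Zero-padded ("same") 3x3 convolution as used in CNN layers.

    x has shape (H, W, C_in); w has shape (3, 3, C_in, C_out); result is (H, W, C_out).
    """
    H, W, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros((H, W, w.shape[-1]))
    for i in range(3):
        for j in range(3):
            # Accumulate the contribution of each kernel tap over all input channels.
            out += xp[i : i + H, j : j + W, :] @ w[i, j]
    return out


def init_weights(depth, channels=16, in_channels=1, seed=0):
    """Random (untrained) weights for `depth` conv layers: first, depth - 2 middle, last."""
    rng = np.random.default_rng(seed)
    shapes = [(3, 3, in_channels, channels)]
    shapes += [(3, 3, channels, channels)] * (depth - 2)
    shapes += [(3, 3, channels, 1)]
    return [rng.normal(0.0, 0.05, s) for s in shapes]


def dncnn_forward(y, weights, sigma=None):
    """DnCNN-style residual denoiser: estimate the noise, then subtract it from y."""
    h = y[..., None]  # (H, W, 1)
    if sigma is not None:
        # Assumed conditioning scheme: append a constant noise-level channel.
        h = np.concatenate([h, np.full_like(h, sigma)], axis=-1)
    for w in weights[:-1]:
        # Conv + ReLU (batch normalization omitted in this sketch).
        h = np.maximum(conv3x3(h, w), 0.0)
    noise_est = conv3x3(h, weights[-1])[..., 0]
    return y - noise_est


y = np.random.default_rng(1).random((32, 32))
x6 = dncnn_forward(y, init_weights(depth=6))
x17n = dncnn_forward(y, init_weights(depth=17, in_channels=2), sigma=0.1)
print(x6.shape, x17n.shape)  # (32, 32) (32, 32)
```

In this sketch, the 6- and 17-layer variants differ only in how many middle convolutions are stacked. The noise-level variant differs only in its first layer, which accepts one extra input channel.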

The networks trained for specific noise levels are labeled 6L, 6M, 6H, 17L, 17M, and 17H, where {6, 17} denotes the number of layers and {L, M, H} denotes the noise standard deviation of the training images (0.06, 0.10, and 0.20, respectively). The networks that take the noise standard deviation as an input are labeled 6N and 17N, where {6, 17} again denotes the number of layers.
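The naming scheme can be decoded mechanically. The helper below is hypothetical and is not part of scico. It returns the layer count and training noise level for a label. For the "N" variants it returns `None`, since those variants receive σ at call time, as in the `denoiser(y, sigma=σ)` call used later in this example.

```python
# Hypothetical helper, not part of scico: decode a DnCNN variant label.
TRAINING_SIGMA = {"L": 0.06, "M": 0.10, "H": 0.20}


def parse_variant(label):
    """Return (num_layers, training_sigma); training_sigma is None for "N" variants."""
    layers, suffix = int(label[:-1]), label[-1]
    if suffix == "N":
        # Noise level is supplied as an input when the denoiser is called.
        return layers, None
    return layers, TRAINING_SIGMA[suffix]


for label in ["6L", "17M", "17H", "6N"]:
    print(label, parse_variant(label))
```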

[1]:
import numpy as np

import komplot as kplt
from xdesign import Foam, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import metric
from scico.denoiser import DnCNN
kplt.config_notebook_plotting()

Create a ground truth image.

[2]:
np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array

Test each DnCNN variant on images corrupted by additive Gaussian white noise at three noise levels, and compare the PSNR of the noisy and denoised images.

[3]:
print("  σ   | variant | noisy image PSNR (dB)   | denoised image PSNR (dB)")
# Generate one standard normal noise realization (fixed seed); it is scaled by σ below.
noise, key = scico.random.randn(x_gt.shape, seed=0)
for σ in [0.06, 0.10, 0.20]:
    print("------+---------+-------------------------+-------------------------")
    # Generate the noisy image for this noise level.
    y = x_gt + σ * noise
    for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:
        # Instantiate a DnCNN.
        denoiser = DnCNN(variant=variant)

        # The "N" variants take the noise standard deviation as an additional input.
        if variant in ["6N", "17N"]:
            x_hat = denoiser(y, sigma=σ)
        else:
            x_hat = denoiser(y)

        # Clip the denoised image to the intensity range [0, 1].
        x_hat = np.clip(x_hat, a_min=0, a_max=1.0)

        # %-8s left-aligns the variant label in the 8-character column.
        print(
            " %.2f | %-8s|          %.2f          |          %.2f          "
            % (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))
        )
  σ   | variant | noisy image PSNR (dB)   | denoised image PSNR (dB)
------+---------+-------------------------+-------------------------
 0.06 | 17L     |          24.44          |          33.87
 0.06 | 17M     |          24.44          |          34.05
 0.06 | 17H     |          24.44          |          26.34
 0.06 | 17N     |          24.44          |          35.66
 0.06 | 6L      |          24.44          |          33.90
 0.06 | 6M      |          24.44          |          29.80
 0.06 | 6H      |          24.44          |          26.89
 0.06 | 6N      |          24.44          |          36.54
------+---------+-------------------------+-------------------------
 0.10 | 17L     |          20.01          |          27.46
 0.10 | 17M     |          20.01          |          31.96
 0.10 | 17H     |          20.01          |          26.48
 0.10 | 17N     |          20.01          |          30.37
 0.10 | 6L      |          20.01          |          27.93
 0.10 | 6M      |          20.01          |          27.50
 0.10 | 6H      |          20.01          |          26.58
 0.10 | 6N      |          20.01          |          33.36
------+---------+-------------------------+-------------------------
 0.20 | 17L     |          13.99          |          18.40
 0.20 | 17M     |          13.99          |          20.13
 0.20 | 17H     |          13.99          |          26.07
 0.20 | 17N     |          13.99          |          21.31
 0.20 | 6L      |          13.99          |          18.73
 0.20 | 6M      |          13.99          |          20.73
 0.20 | 6H      |          13.99          |          24.87
 0.20 | 6N      |          13.99          |          25.92
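The noisy-image column can be checked by hand. With additive Gaussian noise of standard deviation σ, the mean squared error is approximately σ². If the signal range is 1, this gives PSNR ≈ 10 log10(1/σ²) = −20 log10 σ, or about 24.44, 20.00, and 13.98 dB for the three noise levels. These values are close to the 24.44, 20.01, and 13.99 dB in the table. The NumPy sketch below assumes a unit signal range and uses a random binary image as a hypothetical stand-in for the foam phantom. scico's `metric.psnr` may determine the signal range differently, so treat this as an illustration of the formula, not of that function.

```python
import numpy as np


def psnr(reference, test, signal_range=1.0):
    """PSNR in dB, assuming a known signal range (peak-to-peak) for the reference."""
    mse = np.mean((np.asarray(reference) - np.asarray(test)) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


rng = np.random.default_rng(0)
# Hypothetical stand-in for the foam phantom: a random binary image with values in {0, 1}.
x = (rng.random((512, 512)) > 0.5).astype(float)
noise = rng.standard_normal(x.shape)

for sigma in [0.06, 0.10, 0.20]:
    y = x + sigma * noise
    predicted = -20.0 * np.log10(sigma)  # PSNR when MSE = sigma**2 and signal_range = 1
    print(f"sigma={sigma:.2f}  measured={psnr(x, y):.2f} dB  predicted={predicted:.2f} dB")
```

Small deviations from the prediction come from the finite noise sample. This is why the table shows 20.01 rather than 20.00 dB at σ = 0.10.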

Show the reference, noisy, and denoised images for σ = 0.20 and variant 6N (the last case computed in the loop above).

[4]:
fig, ax = kplt.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
kplt.imview(x_gt, cmap="Blues", title="Reference", ax=ax[0])
kplt.imview(y, cmap="Blues", title="Noisy image: %.2f (dB)" % metric.psnr(x_gt, y), ax=ax[1])
kplt.imview(
    x_hat, cmap="Blues", title="Denoised image: %.2f (dB)" % metric.psnr(x_gt, x_hat), ax=ax[2]
)
fig.show()
[Figure: reference, noisy, and denoised images for σ = 0.20, variant 6N]