Comparison of DnCNN Variants for Image Denoising

This example demonstrates the solution of an image denoising problem using DnCNN [58] networks trained for different noise levels, as well as custom variants with fewer layers and with a noise-level input.
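DnCNN [58] learns the residual (the noise) rather than the clean image: a first convolution followed by ReLU, a stack of convolution, batch normalization, and ReLU layers, and a final convolution, with the denoised image obtained by subtracting the network output from the noisy input. The toy NumPy forward pass below illustrates this structure and how the layer count (6 vs. 17) enters. It is a sketch under stated assumptions, not scico's implementation: the weights are random and untrained, batch normalization is omitted, and `conv3x3`, `init_params`, and `dncnn_denoise` are hypothetical names. The noise-level input is modelled, purely as an assumption, by appending a constant σ map as an extra input channel, which is one common conditioning approach; we have not confirmed that the 6N and 17N variants work this way.

```python
import numpy as np


def conv3x3(x, w, b):
    """'Same'-padded 3x3 cross-correlation of an (H, W, Cin) array.

    w has shape (3, 3, Cin, Cout) and b has shape (Cout,).
    """
    h, wd, _ = x.shape
    xp = np.pad(x, ((1, 1), (1, 1), (0, 0)))
    out = np.zeros((h, wd, w.shape[-1]))
    for dy in range(3):
        for dx in range(3):
            # Contribution of one kernel tap, summed over all input channels.
            out += xp[dy:dy + h, dx:dx + wd, :] @ w[dy, dx]
    return out + b


def init_params(depth, features=8, in_channels=1, seed=0):
    """Random (untrained) weights for a hypothetical stack of `depth` conv layers."""
    rng = np.random.default_rng(seed)
    chans = [in_channels] + [features] * (depth - 1) + [1]
    return [
        (0.1 * rng.standard_normal((3, 3, cin, cout)), np.zeros(cout))
        for cin, cout in zip(chans[:-1], chans[1:])
    ]


def dncnn_denoise(y, params, sigma=None):
    """Residual denoising: the network estimates the noise, which is subtracted from y."""
    x = y[..., np.newaxis]  # (H, W) -> (H, W, 1)
    if sigma is not None:
        # Assumed noise-level conditioning: append a constant sigma map as a channel.
        x = np.concatenate([x, np.full_like(x, sigma)], axis=-1)
    for i, (w, b) in enumerate(params):
        x = conv3x3(x, w, b)
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)  # ReLU (batch normalization omitted in this sketch)
    return y - x[..., 0]


rng = np.random.default_rng(1)
y = rng.random((32, 32))
p6 = init_params(6)                   # 6-layer network
p17 = init_params(17)                 # 17-layer network
p6n = init_params(6, in_channels=2)   # 6 layers, image plus sigma-map input
x6 = dncnn_denoise(y, p6)
x17 = dncnn_denoise(y, p17)
x6n = dncnn_denoise(y, p6n, sigma=0.1)
print(len(p6), len(p17), x6.shape, x6n.shape)
```

Because the output is `y` minus the predicted noise, a network whose weights are all zero returns its input unchanged; the depth only changes how many convolutions are stacked between input and residual.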

The networks trained for specific noise levels are labeled 6L, 6M, 6H, 17L, 17M, and 17H, where {6, 17} denote the number of layers, and {L, M, H} represent the noise standard deviation of the training images (0.06, 0.10, and 0.20, respectively). The networks that take the noise standard deviation as an input are labeled 6N and 17N, where {6, 17} again denote the number of layers.
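The labeling scheme can be decoded mechanically: the numeric prefix is the layer count and the letter suffix is either the training noise level or N for a noise-level input. The helper below is a hypothetical illustration (`TRAINING_SIGMA` and `parse_variant` are not part of scico); the σ values are the ones stated above.

```python
# Hypothetical mapping from suffix letter to training noise standard deviation.
TRAINING_SIGMA = {"L": 0.06, "M": 0.10, "H": 0.20}


def parse_variant(label):
    """Return (layers, training sigma), with sigma None for noise-level-input variants."""
    layers, suffix = int(label[:-1]), label[-1]
    if suffix == "N":
        return layers, None  # sigma is supplied when the denoiser is called
    return layers, TRAINING_SIGMA[suffix]


for label in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:
    print(label, parse_variant(label))
```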

import numpy as np

from xdesign import Foam, discrete_phantom

import scico.numpy as snp
import scico.random
from scico import metric, plot
from scico.denoiser import DnCNN
plot.config_notebook_plotting()

Create a ground truth image.

np.random.seed(1234)
N = 512  # image size
x_gt = discrete_phantom(Foam(size_range=[0.075, 0.0025], gap=1e-3, porosity=1), size=N)
x_gt = snp.array(x_gt)  # convert to jax array

Test different DnCNN variants on images with different noise levels.

print("  σ   | variant | noisy image PSNR (dB)   | denoised image PSNR (dB)")
# The same standard-normal noise realization (seed=0) is used for every σ and variant.
noise, _ = scico.random.randn(x_gt.shape, seed=0)
for σ in [0.06, 0.10, 0.20]:
    print("------+---------+-------------------------+-------------------------")
    # Generate a noisy image.
    y = x_gt + σ * noise

    for variant in ["17L", "17M", "17H", "17N", "6L", "6M", "6H", "6N"]:
        # Instantiate a DnCNN.
        denoiser = DnCNN(variant=variant)

        # Variants 6N and 17N take the noise standard deviation as an input.
        if variant in ["6N", "17N"]:
            x_hat = denoiser(y, sigma=σ)
        else:
            x_hat = denoiser(y)

        # Clip the denoised image to the valid intensity range [0, 1].
        x_hat = np.clip(x_hat, a_min=0, a_max=1.0)

        # %-3s left-justifies the variant label so that the columns stay aligned.
        print(
            " %.2f | %-3s     |          %.2f          |          %.2f          "
            % (σ, variant, metric.psnr(x_gt, y), metric.psnr(x_gt, x_hat))
        )
  σ   | variant | noisy image PSNR (dB)   | denoised image PSNR (dB)
------+---------+-------------------------+-------------------------
 0.06 | 17L     |          24.43          |          33.82
 0.06 | 17M     |          24.43          |          33.94
 0.06 | 17H     |          24.43          |          26.32
 0.06 | 17N     |          24.43          |          35.48
 0.06 | 6L      |          24.43          |          33.80
 0.06 | 6M      |          24.43          |          29.76
 0.06 | 6H      |          24.43          |          26.86
 0.06 | 6N      |          24.43          |          36.30
------+---------+-------------------------+-------------------------
 0.10 | 17L     |          19.99          |          27.43
 0.10 | 17M     |          19.99          |          31.82
 0.10 | 17H     |          19.99          |          26.44
 0.10 | 17N     |          19.99          |          30.30
 0.10 | 6L      |          19.99          |          27.87
 0.10 | 6M      |          19.99          |          27.45
 0.10 | 6H      |          19.99          |          26.52
 0.10 | 6N      |          19.99          |          33.09
------+---------+-------------------------+-------------------------
 0.20 | 17L     |          13.97          |          18.37
 0.20 | 17M     |          13.97          |          20.12
 0.20 | 17H     |          13.97          |          25.97
 0.20 | 17N     |          13.97          |          21.38
 0.20 | 6L      |          13.97          |          18.70
 0.20 | 6M      |          13.97          |          20.70
 0.20 | 6H      |          13.97          |          24.78
 0.20 | 6N      |          13.97          |          25.71
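The noisy-image column can be checked by hand. If PSNR is computed as 10 log10(R² / MSE), with R the dynamic range of the reference, then adding zero-mean Gaussian noise of standard deviation σ gives MSE ≈ σ², so the noisy PSNR is approximately 20 log10(R / σ). The sketch below re-implements PSNR in NumPy under the assumption that it matches the default behaviour of `metric.psnr` (range taken from the reference) and that the phantom has R = 1; `psnr` and `expected_noisy_psnr` are hypothetical names, and a random binary image stands in for the foam phantom.

```python
import numpy as np


def psnr(reference, comparison):
    """PSNR in dB, taking the peak signal as the dynamic range of `reference`."""
    signal_range = np.max(reference) - np.min(reference)
    mse = np.mean((reference - comparison) ** 2)
    return 10.0 * np.log10(signal_range**2 / mse)


def expected_noisy_psnr(sigma, signal_range=1.0):
    """Large-image limit of the PSNR of reference + N(0, sigma^2) noise."""
    return 20.0 * np.log10(signal_range / sigma)


rng = np.random.default_rng(0)
x = rng.integers(0, 2, size=(512, 512)).astype(float)  # binary stand-in image, range 1
noise = rng.standard_normal(x.shape)
for sigma in [0.06, 0.10, 0.20]:
    y = x + sigma * noise
    print(f"σ={sigma:.2f}: measured {psnr(x, y):.2f} dB, "
          f"predicted {expected_noisy_psnr(sigma):.2f} dB")
```

The predicted values, about 24.44, 20.00, and 13.98 dB, agree closely with the noisy-image PSNRs of 24.43, 19.99, and 13.97 dB in the table, consistent with the phantom having unit dynamic range.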

Show the reference, noisy, and denoised images for σ = 0.20 and variant 6N (the last case evaluated in the loop above).

fig, ax = plot.subplots(nrows=1, ncols=3, sharex=True, sharey=True, figsize=(21, 7))
plot.imview(x_gt, title="Reference", fig=fig, ax=ax[0])
plot.imview(y, title="Noisy image PSNR: %.2f dB" % metric.psnr(x_gt, y), fig=fig, ax=ax[1])
plot.imview(x_hat, title="Denoised image PSNR: %.2f dB" % metric.psnr(x_gt, x_hat), fig=fig, ax=ax[2])
fig.show()
[Figure: three panels titled "Reference", "Noisy image", and "Denoised image"]