Source code for scico.denoiser

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Interfaces to standard denoisers."""


from typing import Any, Optional, Union

import numpy as np

import jax

try:
    import bm3d as tubm3d
except ImportError:
    have_bm3d = False
    BM3DProfile = Any
else:
    have_bm3d = True
    from bm3d.profiles import BM3DProfile  # type: ignore

try:
    import bm4d as tubm4d
except ImportError:
    have_bm4d = False
    BM4DProfile = Any
else:
    have_bm4d = True
    from bm4d.profiles import BM4DProfile  # type: ignore

import scico.numpy as snp
from scico.data import _flax_data_path
from scico.flax import DnCNNNet, FlaxMap, load_weights


[docs]def bm3d(x: snp.Array, sigma: float, is_rgb: bool = False, profile: Union[BM3DProfile, str] = "np"): r"""An interface to the BM3D denoiser :cite:`dabov-2008-image`. BM3D denoising is performed using the `code <https://pypi.org/project/bm3d>`__ released with :cite:`makinen-2019-exact`. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. Args: x: Input image. Expected to be a 2D array (gray-scale denoising) or 3D array (color denoising). Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. For color denoising, the color channel is assumed to be in the last non-singleton dimension. sigma: Noise parameter. is_rgb: Flag indicating use of BM3D with a color transform. Default: ``False``. profile: Parameter configuration for BM3D. Returns: Denoised output. """ if not have_bm3d: raise RuntimeError("Package bm3d is required for use of this function.") if is_rgb is True: def bm3d_eval(x: snp.Array, sigma: float): return tubm3d.bm3d_rgb(x, sigma, profile=profile) else: def bm3d_eval(x: snp.Array, sigma: float): return tubm3d.bm3d(x, sigma, profile=profile) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"BM3D requires real-valued inputs, got {x.dtype}.") # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "BM3D requires two-dimensional or three dimensional inputs; got ndim = {x.ndim}." ) # This check is also performed inside the BM3D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM3D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) if profile == "np" and np.min(x.shape[:2]) < 8: raise ValueError( "Two leading dimensions of input cannot be smaller than block size " f"(8); got image size = {x.shape}." ) if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when " " the additional axes are singletons." ) y = jax.pure_callback( lambda args: bm3d_eval(*args).astype(x.dtype), jax.ShapeDtypeStruct(x.shape, x.dtype), (x, sigma), ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) return y
[docs]def bm4d(x: snp.Array, sigma: float, profile: Union[BM4DProfile, str] = "np"): r"""An interface to the BM4D denoiser :cite:`maggioni-2012-nonlocal`. BM4D denoising is performed using the `code <https://pypi.org/project/bm4d/>`__ released by the authors of :cite:`maggioni-2012-nonlocal`. Since this package is an interface to compiled C code, JAX features such as automatic differentiation and support for GPU devices are not available. Args: x: Input image. Expected to be a 3D array. Higher-dimensional arrays are tolerated only if the additional dimensions are singletons. sigma: Noise parameter. profile: Parameter configuration for BM4D. Returns: Denoised output. """ if not have_bm4d: raise RuntimeError("Package bm4d is required for use of this function.") def bm4d_eval(x: snp.Array, sigma: float): return tubm4d.bm4d(x, sigma, profile=profile) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"BM4D requires real-valued inputs, got {x.dtype}.") # Support arrays with more than three axes when the additional axes are singletons. x_in_shape = x.shape if isinstance(x.ndim, tuple) or x.ndim < 3: raise ValueError(f"BM4D requires three-dimensional inputs; got ndim = {x.ndim}.") # This check is also performed inside the BM4D call, but due to the callback, # no exception is raised and the program will crash with no traceback. # NOTE: if BM4D is extended to allow for different profiles, the block size must be # updated; this presumes 'np' profile (bs=8) if profile == "np" and np.min(x.shape[:3]) < 8: raise ValueError( "Three leading dimensions of input cannot be smaller than block size " f"(8); got image size = {x.shape}." ) if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when " " the additional axes are singletons." ) y = jax.pure_callback( lambda args: bm4d_eval(*args).astype(x.dtype), jax.ShapeDtypeStruct(x.shape, x.dtype), (x, sigma), ) # undo squeezing, if neccessary y = y.reshape(x_in_shape) return y
[docs]class DnCNN(FlaxMap): """Flax implementation of the DnCNN denoiser. A flax implementation of the DnCNN denoiser :cite:`zhang-2017-dncnn`. Note that :class:`.DnCNNNet` represents an untrained form of the generic DnCNN CNN structure, while this class represents a trained form with six or seventeen layers. The standard DnCNN as proposed in :cite:`zhang-2017-dncnn` does not have a noise level input. This implementation of DnCNN also supports a custom variant that includes a noise standard deviation input, `sigma`, which is included in the network as an additional channel consisting of a constant array with value `sigma`. This network was trained with image data on the range [0, 1], and with noise standard deviations ranging from 0.0 to 0.2. It is worth noting that DRUNet :cite:`zhang-2021-plug`, another recent approach to including a noise level input in a CNN denoiser, is based on a substantially different network architecture. """ def __init__(self, variant: str = "6M"): """ Note that all DnCNN models are trained for single-channel image input. Multi-channel input is supported via independent denoising of each channel. Input images are expected to have pixel values in the range [0, 1]. Args: variant: Identify the DnCNN model to be used. Options are '6L', '6M' (default), '6H', '6N', '17L', '17M', '17H', and '17N', where the integer indicates the number of layers in the network, and the postfix indicates the training noise standard deviation (with respect to data in the range [0, 1]): L (low) = 0.06, M (mid) = 0.10, H (high) = 0.20, or N indicating that a noise standard deviation input, `sigma`, is available. """ self.variant = variant if variant not in ["6L", "6M", "6H", "17L", "17M", "17H", "6N", "17N"]: raise ValueError(f"Invalid value {variant} of parameter variant.") if variant[0] == "6": nlayer = 6 else: nlayer = 17 channels = 2 if variant in ["6N", "17N"] else 1 if variant in ["6N", "17N"]: self.is_blind = False else: self.is_blind = True model = DnCNNNet(depth=nlayer, channels=channels, num_filters=64, dtype=np.float32) variables = load_weights(_flax_data_path("dncnn%s.npz" % variant)) super().__init__(model, variables)
[docs] def __call__(self, x: snp.Array, sigma: Optional[float] = None) -> snp.Array: r"""Apply DnCNN denoiser. Args: x: Input array. sigma: Noise standard deviation (for variants `6N` and `17N`). Returns: Denoised output. """ if sigma is not None and self.is_blind: raise ValueError( "A non-default value for the sigma parameter may " "only be specified when the variant is 6N or 17N" f"; got variant = {self.variant}." ) if sigma is None and not self.is_blind: raise ValueError( "A float value must be specified for the sigma " "parameter when the variant is 6N or 17N." ) if snp.util.is_complex_dtype(x.dtype): raise TypeError(f"DnCNN requries real-valued inputs, got {x.dtype}.") if isinstance(x.ndim, tuple) or x.ndim < 2: raise ValueError( "DnCNN requires two-dimensional (M, N) or three-dimensional (M, N, C)" f" inputs; got ndim = {x.ndim}." ) x_in_shape = x.shape if x.ndim > 3: if all(k == 1 for k in x.shape[3:]): x = x.squeeze() else: raise ValueError( "Arrays with more than three axes are only supported when" " the additional axes are singletons." ) if x.ndim == 3: y = snp.swapaxes(x, 0, -1) if sigma is not None: y = snp.stack([y, snp.ones_like(y) * sigma], -1) else: y = y[..., np.newaxis] # swap channel axis to batch axis and add singleton axis at end y = super().__call__(y) # drop singleton axis and swap axes back to original positions y = snp.swapaxes(y[..., 0], 0, -1) else: if sigma is not None: x = snp.stack([x, snp.ones_like(x) * sigma], -1) x = x[np.newaxis, ...] y = super().__call__(x) if sigma is not None: y = y[0, ..., 0] y = y.reshape(x_in_shape) return y