Source code for scico.functional._denoiser

# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Pseudo-functionals that have denoisers as their proximal operators."""


from typing import Union

from scico import denoiser
from scico.numpy import Array

from ._functional import Functional


class BM3D(Functional):
    r"""Pseudo-functional with the BM3D denoiser as its proximal operator.

    The proximal operator of this pseudo-functional is the BM3D
    algorithm :cite:`dabov-2008-image`, applied by calling
    :func:`.denoiser.bm3d`. That function is an interface to compiled
    C code, so JAX features such as automatic differentiation and
    support for GPU devices are not available.
    """

    has_eval = False
    has_prox = True

    def __init__(self, is_rgb: bool = False, profile: Union[denoiser.BM3DProfile, str] = "np"):
        r"""Initialize a :class:`BM3D` object.

        Args:
            is_rgb: Flag indicating use of BM3D with a color transform.
                Default: ``False``.
            profile: Parameter configuration for BM3D.
        """
        # Stored so that prox can forward them to denoiser.bm3d.
        self.is_rgb, self.profile = is_rgb, profile
        super().__init__()

    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore
        r"""Apply BM3D denoiser.

        Args:
            x: Input image.
            lam: Noise parameter.
            kwargs: Additional arguments that may be used by derived
                classes. They are not used by this method.

        Returns:
            Denoised output.
        """
        denoised = denoiser.bm3d(x, lam, self.is_rgb, profile=self.profile)
        return denoised
class BM4D(Functional):
    r"""Pseudo-functional with the BM4D denoiser as its proximal operator.

    The proximal operator of this pseudo-functional is the BM4D
    algorithm :cite:`maggioni-2012-nonlocal`, applied by calling
    :func:`.denoiser.bm4d`. That function is an interface to compiled
    C code, so JAX features such as automatic differentiation and
    support for GPU devices are not available.
    """

    has_eval = False
    has_prox = True

    def __init__(self, profile: Union[denoiser.BM4DProfile, str] = "np"):
        r"""Initialize a :class:`BM4D` object.

        Args:
            profile: Parameter configuration for BM4D.
        """
        # Stored so that prox can forward it to denoiser.bm4d.
        self.profile = profile
        super().__init__()

    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore
        r"""Apply BM4D denoiser.

        Args:
            x: Input image.
            lam: Noise parameter.
            kwargs: Additional arguments that may be used by derived
                classes. They are not used by this method.

        Returns:
            Denoised output.
        """
        denoised = denoiser.bm4d(x, lam, profile=self.profile)
        return denoised
class DnCNN(Functional):
    """Pseudo-functional whose prox applies the DnCNN denoising algorithm.

    A pseudo-functional that has the DnCNN algorithm
    :cite:`zhang-2017-dncnn` as its proximal operator, implemented via
    :class:`.denoiser.DnCNN`.
    """

    has_eval = False
    has_prox = True

    def __init__(self, variant: str = "6M"):
        """Initialize a :class:`DnCNN` object.

        Args:
            variant: Identify the DnCNN model to be used. See
                :class:`.denoiser.DnCNN` for valid values.
        """
        self.dncnn = denoiser.DnCNN(variant)

        # Give both kinds of model the same (x, sigma) call signature so
        # that prox does not need to branch on the model type.
        if self.dncnn.is_blind:

            def denoise(x, sigma):
                # Blind models are called without a noise level, so
                # sigma is discarded.
                return self.dncnn(x)

        else:

            def denoise(x, sigma):
                return self.dncnn(x, sigma)

        self._denoise = denoise
        # Run the base-class initialization, as the other denoiser
        # pseudo-functionals in this module do.
        super().__init__()

    def prox(self, x: Array, lam: float = 1.0, **kwargs) -> Array:  # type: ignore
        r"""Apply DnCNN denoiser.

        *Warning*: When the selected model is blind (the ``is_blind``
        attribute of the underlying :class:`.denoiser.DnCNN` object is
        true), the `lam` parameter is ignored and has no effect on the
        output. Otherwise `lam` is passed to the denoiser as its second
        argument.

        Args:
            x: Input array.
            lam: Noise parameter (ignored for blind models).
            kwargs: Additional arguments that may be used by derived
                classes. They are not used by this method.

        Returns:
            Denoised output.
        """
        return self._denoise(x, lam)