# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Flax implementation of different convolutional nets."""
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
from functools import partial
from typing import Any, Callable, Tuple
import jax.numpy as jnp
from flax.core import Scope # noqa
from flax.linen import BatchNorm, Conv, max_pool, relu
from flax.linen.initializers import kaiming_normal, xavier_normal
from flax.linen.module import _Sentinel # noqa
from flax.linen.module import Module, compact
from scico.flax.blocks import (
ConvBNBlock,
ConvBNMultiBlock,
ConvBNPoolBlock,
ConvBNUpsampleBlock,
upscale_nn,
)
from scico.numpy import Array
# The imports of Scope and _Sentinel (above) are required to silence
# "cannot resolve forward reference" warnings when building sphinx api
# docs.
ModuleDef = Any
class DnCNNNet(Module):
r"""Flax implementation of DnCNN :cite:`zhang-2017-dncnn`.
Flax implementation of the convolutional neural network (CNN)
architecture for denoising described in :cite:`zhang-2017-dncnn`.
Attributes:
depth: Number of layers in the neural network.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layers.
kernel_size: Size of the convolution filters.
strides: Convolution strides.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
act: Class of activation function to apply. Default:
:func:`~flax.linen.activation.relu`.
"""
depth: int
channels: int
num_filters: int = 64
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
dtype: Any = jnp.float32
act: Callable = relu
[docs]
@compact
def __call__(
self,
inputs: Array,
train: bool = True,
) -> Array:
"""Apply DnCNN denoiser.
Args:
inputs: The array to be transformed.
train: Flag to differentiate between training and testing stages.
Returns:
The denoised input.
"""
# Definition using arguments common to all convolutions.
conv = partial(
Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=kaiming_normal()
)
# Definition using arguments common to all batch normalizations.
norm = partial(
BatchNorm,
use_running_average=not train,
momentum=0.99,
epsilon=1e-5,
dtype=self.dtype,
)
# Definition and application of DnCNN model.
base = inputs
y = conv(
self.num_filters,
self.kernel_size,
strides=self.strides,
name="conv_start",
)(inputs)
y = self.act(y)
for _ in range(self.depth - 2):
y = ConvBNBlock(
self.num_filters,
conv=conv,
norm=norm,
act=self.act,
kernel_size=self.kernel_size,
strides=self.strides,
)(y)
y = conv(
self.channels,
self.kernel_size,
strides=self.strides,
name="conv_end",
)(y)
return base - y # residual-like network
class ResNet(Module):
"""Flax implementation of convolutional network with residual connection.
Net constructed from sucessive applications of convolution plus batch
normalization blocks and ending with residual connection (i.e. adding
the input to the output of the block).
Args:
depth: Depth of residual net.
channels: Number of channels of input tensor.
num_filters: Number of filters in the layers of the block.
Corresponds to the number of channels in the network
processing.
kernel_size: Size of the convolution filters.
strides: Convolution strides.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
depth: int
channels: int
num_filters: int = 64
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
dtype: Any = jnp.float32
[docs]
@compact
def __call__(self, x: Array, train: bool = True) -> Array:
"""Apply ResNet.
Args:
x: The array to be transformed.
train: Flag to differentiate between training and testing stages.
Returns:
The ResNet result.
"""
residual = x
# Definition using arguments common to all convolutions.
conv = partial(
Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=xavier_normal()
)
# Definition using arguments common to all batch normalizations.
norm = partial(
BatchNorm,
use_running_average=not train,
momentum=0.99,
epsilon=1e-5,
dtype=self.dtype,
)
act = relu
# Definition and application of ResNet.
for _ in range(self.depth - 1):
x = ConvBNBlock(
self.num_filters,
conv=conv,
norm=norm,
act=act,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
x = conv(
self.channels,
self.kernel_size,
strides=self.strides,
)(x)
x = norm()(x)
return x + residual
class ConvBNNet(Module):
"""Convolution and batch normalization net.
Net constructed from sucessive applications of convolution plus batch
normalization blocks. No residual connection.
Args:
depth: Depth of net.
channels: Number of channels of input tensor.
num_filters: Number of filters in the layers of the block.
Corresponds to the number of channels in the network
processing.
kernel_size: Size of the convolution filters.
strides: Convolution strides.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
depth: int
channels: int
num_filters: int = 64
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
dtype: Any = jnp.float32
[docs]
@compact
def __call__(self, x: Array, train: bool = True) -> Array:
"""Apply ConvBNNet.
Args:
x: The array to be transformed.
train: Flag to differentiate between training and testing stages.
Returns:
The ConvBNNet result.
"""
# Definition using arguments common to all convolutions.
conv = partial(
Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=xavier_normal()
)
# Definition using arguments common to all batch normalizations.
norm = partial(
BatchNorm,
use_running_average=not train,
momentum=0.99,
epsilon=1e-5,
dtype=self.dtype,
)
act = relu
# Definition and application of ConvBNNet.
for _ in range(self.depth - 1):
x = ConvBNBlock(
self.num_filters,
conv=conv,
norm=norm,
act=act,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
x = conv(
self.channels,
self.kernel_size,
strides=self.strides,
)(x)
x = norm()(x)
return x
class UNet(Module):
"""Flax implementation of U-Net model :cite:`ronneberger-2015-unet`.
Args:
depth: Depth of U-Net.
channels: Number of channels of input tensor.
num_filters: Number of filters in the convolutional layer of the
block. Corresponds to the number of channels in the network
processing.
kernel_size: Size of the convolution filters.
strides: Convolution strides.
block_depth: Number of processing layers per block.
window_shape: Window for reduction for pooling and downsampling.
upsampling: Factor for expanding.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
depth: int
channels: int
num_filters: int = 64
kernel_size: Tuple[int, int] = (3, 3)
strides: Tuple[int, int] = (1, 1)
block_depth: int = 2
window_shape: Tuple[int, int] = (2, 2)
upsampling: int = 2
dtype: Any = jnp.float32
[docs]
@compact
def __call__(self, x: Array, train: bool = True) -> Array:
"""Apply U-Net.
Args:
x: The array to be transformed.
train: Flag to differentiate between training and testing stages.
Returns:
The U-Net result.
"""
# Definition using arguments common to all convolutions.
conv = partial(
Conv, use_bias=False, padding="CIRCULAR", dtype=self.dtype, kernel_init=kaiming_normal()
)
# Definition using arguments common to all batch normalizations.
norm = partial(
BatchNorm,
use_running_average=not train,
momentum=0.99,
epsilon=1e-5,
dtype=self.dtype,
)
act = relu
# Definition of upscaling function.
upfn = partial(upscale_nn, scale=self.upsampling)
# Definition and application of U-Net.
x = ConvBNMultiBlock(
self.block_depth,
self.num_filters,
conv=conv,
norm=norm,
act=act,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
residual = []
# going down
j: int = 1
for _ in range(self.depth - 1):
residual.append(x) # for skip connections
x = ConvBNPoolBlock(
2 * j * self.num_filters,
conv=conv,
norm=norm,
act=act,
pool=max_pool,
kernel_size=self.kernel_size,
strides=self.strides,
window_shape=self.window_shape,
)(x)
x = ConvBNMultiBlock(
self.block_depth,
2 * j * self.num_filters,
conv=conv,
norm=norm,
act=act,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
j = 2 * j
# going up
j = j // 2 # undo last
res_ind = -1
for _ in range(self.depth - 1):
x = ConvBNUpsampleBlock(
j * self.num_filters,
conv=conv,
norm=norm,
act=act,
upfn=upfn,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
# skip connection
x = jnp.concatenate((residual[res_ind], x), axis=3)
x = ConvBNMultiBlock(
self.block_depth,
j * self.num_filters,
conv=conv,
norm=norm,
act=act,
kernel_size=self.kernel_size,
strides=self.strides,
)(x)
res_ind -= 1
j = j // 2
# final conv1x1
ksz_out = (1, 1)
x = conv(self.channels, ksz_out, strides=self.strides)(x)
return x