Source code for scico.flax.blocks

# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Flax implementation of different convolutional blocks."""

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)

from typing import Any, Callable, Tuple

import jax.numpy as jnp

from flax.core import Scope  # noqa
from flax.linen.module import _Sentinel  # noqa
from flax.linen.module import Module, compact
from scico.numpy import Array

# The imports of Scope and _Sentinel (above) are required to silence
# "cannot resolve forward reference" warnings when building sphinx api
# docs.

ModuleDef = Any


[docs]class ConvBNBlock(Module): """Define convolution and batch normalization Flax block. Args: num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. conv: Flax module implementing the convolution layer to apply. norm: Flax module implementing the batch normalization layer to apply. act: Flax function defining the activation operation to apply. kernel_size: A shape tuple defining the size of the convolution filters. Default: (3, 3). strides: A shape tuple defining the size of strides in convolution. Default: (1, 1). """ num_filters: int conv: ModuleDef norm: ModuleDef act: Callable[..., Array] kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1)
[docs] @compact def __call__( self, inputs: Array, ) -> Array: """Apply convolution followed by normalization and activation. Args: inputs: The array to be transformed. Returns: The transformed input. """ y = self.conv( self.num_filters, self.kernel_size, strides=self.strides, )(inputs) y = self.norm()(y) return self.act(y)
[docs]class ConvBlock(Module): """Define Flax convolution block. Args: num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. conv: Flax module implementing the convolution layer to apply. act: Flax function defining the activation operation to apply. kernel_size: A shape tuple defining the size of the convolution filters. Default: (3, 3). strides: A shape tuple defining the size of strides in convolution. Default: (1, 1). """ num_filters: int conv: ModuleDef act: Callable[..., Array] kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1)
[docs] @compact def __call__( self, inputs: Array, ) -> Array: """Apply convolution followed by activation. Args: inputs: The array to be transformed. Returns: The transformed input. """ y = self.conv( self.num_filters, self.kernel_size, strides=self.strides, )(inputs) return self.act(y)
[docs]class ConvBNPoolBlock(Module): """Define convolution, batch normalization and pooling Flax block. Args: num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. conv: Flax module implementing the convolution layer to apply. norm: Flax module implementing the batch normalization layer to apply. act: Flax function defining the activation operation to apply. pool: Flax function defining the pooling operation to apply. kernel_size: A shape tuple defining the size of the convolution filters. strides: A shape tuple defining the size of strides in convolution. window_shape: A shape tuple defining the window to reduce over in the pooling operation. """ num_filters: int conv: ModuleDef norm: ModuleDef act: Callable[..., Array] pool: Callable[..., Array] kernel_size: Tuple[int, int] strides: Tuple[int, int] window_shape: Tuple[int, int]
[docs] @compact def __call__( self, inputs: Array, ) -> Array: """Apply convolution followed by normalization, activation and pooling. Args: inputs: The array to be transformed. Returns: The transformed input. """ y = self.conv( self.num_filters, self.kernel_size, strides=self.strides, )(inputs) y = self.norm()(y) y = self.act(y) # 'SAME': pads so as to have the same output shape as input if the stride is 1. return self.pool(y, self.window_shape, strides=self.window_shape, padding="SAME")
[docs]class ConvBNUpsampleBlock(Module): """Define convolution, batch normalization and upsample Flax block. Args: num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. conv: Flax module implementing the convolution layer to apply. norm: Flax module implementing the batch normalization layer to apply. act: Flax function defining the activation operation to apply. upfn: Flax function defining the upsampling operation to apply. kernel_size: A shape tuple defining the size of the convolution filters. strides: A shape tuple defining the size of strides in convolution. """ num_filters: int conv: ModuleDef norm: ModuleDef act: Callable[..., Array] upfn: Callable[..., Array] kernel_size: Tuple[int, int] strides: Tuple[int, int]
[docs] @compact def __call__( self, inputs: Array, ) -> Array: """Apply convolution followed by normalization, activation and upsampling. Args: inputs: The array to be transformed. Returns: The transformed input. """ y = self.conv( self.num_filters, self.kernel_size, strides=self.strides, )(inputs) y = self.norm()(y) y = self.act(y) return self.upfn(y)
[docs]class ConvBNMultiBlock(Module): """Block constructed from sucessive applications of :class:`ConvBNBlock`. Args: num_blocks: Number of convolutional batch normalization blocks to apply. Each block has its own parameters for convolution and batch normalization. num_filters: Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor. conv: Flax module implementing the convolution layer to apply. norm: Flax module implementing the batch normalization layer to apply. act: Flax function defining the activation operation to apply. kernel_size: A shape tuple defining the size of the convolution filters. Default: (3, 3). strides: A shape tuple defining the size of strides in convolution. Default: (1, 1). """ num_blocks: int num_filters: int conv: ModuleDef norm: ModuleDef act: Callable[..., Array] kernel_size: Tuple[int, int] = (3, 3) strides: Tuple[int, int] = (1, 1)
[docs] @compact def __call__( self, x: Array, ) -> Array: """Apply sucessive convolution normalization and activation blocks. Apply sucessive blocks, each one composed of convolution normalization and activation. Args: x: The array to be transformed. Returns: The transformed input. """ for _ in range(self.num_blocks): x = ConvBNBlock( self.num_filters, conv=self.conv, norm=self.norm, act=self.act, kernel_size=self.kernel_size, strides=self.strides, )(x) return x
[docs]def upscale_nn(x: Array, scale: int = 2) -> Array: """Nearest neighbor upscale for image batches of shape (N, H, W, C). Args: x: Input tensor of shape (N, H, W, C). scale: Integer scaling factor. Returns: Output tensor of shape (N, H * scale, W * scale, C). """ s = x.shape x = x.reshape((s[0],) + (s[1], 1, s[2], 1) + (s[3],)) x = jnp.tile(x, (1, 1, scale, 1, scale, 1)) return x.reshape((s[0],) + (scale * s[1], scale * s[2]) + (s[3],))