scico.flax.blocks

Flax implementation of different convolutional blocks.

Functions

upscale_nn(x[, scale])

Nearest neighbor upscale for image batches of shape (N, H, W, C).

Classes

ConvBNBlock(num_filters, conv, norm, act[, ...])

Define convolution and batch normalization Flax block.

ConvBNMultiBlock(num_blocks, num_filters, ...)

Block constructed from successive applications of ConvBNBlock.

ConvBNPoolBlock(num_filters, conv, norm, ...)

Define convolution, batch normalization and pooling Flax block.

ConvBNUpsampleBlock(num_filters, conv, norm, ...)

Define convolution, batch normalization and upsample Flax block.

ConvBlock(num_filters, conv, act[, ...])

Define Flax convolution block.

class scico.flax.blocks.ConvBNBlock(num_filters, conv, norm, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Define convolution and batch normalization Flax block.

Parameters:
  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • conv (Any) – Flax module implementing the convolution layer to apply.

  • norm (Any) – Flax module implementing the batch normalization layer to apply.

  • act (Callable[..., Array]) – Flax function defining the activation operation to apply.

  • kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – A shape tuple defining the strides of the convolution. Default: (1, 1).

__call__(inputs)

Apply convolution followed by normalization and activation.

Parameters:

inputs (Array) – The array to be transformed.

Return type:

Array

Returns:

The transformed input.

class scico.flax.blocks.ConvBlock(num_filters, conv, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Define Flax convolution block.

Parameters:
  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • conv (Any) – Flax module implementing the convolution layer to apply.

  • act (Callable[..., Array]) – Flax function defining the activation operation to apply.

  • kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – A shape tuple defining the strides of the convolution. Default: (1, 1).

__call__(inputs)

Apply convolution followed by activation.

Parameters:

inputs (Array) – The array to be transformed.

Return type:

Array

Returns:

The transformed input.

class scico.flax.blocks.ConvBNPoolBlock(num_filters, conv, norm, act, pool, kernel_size, strides, window_shape, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Define convolution, batch normalization and pooling Flax block.

Parameters:
  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • conv (Any) – Flax module implementing the convolution layer to apply.

  • norm (Any) – Flax module implementing the batch normalization layer to apply.

  • act (Callable[..., Array]) – Flax function defining the activation operation to apply.

  • pool (Callable[..., Array]) – Flax function defining the pooling operation to apply.

  • kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.

  • strides (Tuple[int, int]) – A shape tuple defining the strides of the convolution.

  • window_shape (Tuple[int, int]) – A shape tuple defining the window to reduce over in the pooling operation.

__call__(inputs)

Apply convolution followed by normalization, activation and pooling.

Parameters:

inputs (Array) – The array to be transformed.

Return type:

Array

Returns:

The transformed input.

class scico.flax.blocks.ConvBNUpsampleBlock(num_filters, conv, norm, act, upfn, kernel_size, strides, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Define convolution, batch normalization and upsample Flax block.

Parameters:
  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • conv (Any) – Flax module implementing the convolution layer to apply.

  • norm (Any) – Flax module implementing the batch normalization layer to apply.

  • act (Callable[..., Array]) – Flax function defining the activation operation to apply.

  • upfn (Callable[..., Array]) – Flax function defining the upsampling operation to apply.

  • kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters.

  • strides (Tuple[int, int]) – A shape tuple defining the strides of the convolution.

__call__(inputs)

Apply convolution followed by normalization, activation and upsampling.

Parameters:

inputs (Array) – The array to be transformed.

Return type:

Array

Returns:

The transformed input.

class scico.flax.blocks.ConvBNMultiBlock(num_blocks, num_filters, conv, norm, act, kernel_size=(3, 3), strides=(1, 1), parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Block constructed from successive applications of ConvBNBlock.

Parameters:
  • num_blocks (int) – Number of blocks, each composed of convolution, batch normalization and activation, to apply. Each block has its own parameters for convolution and batch normalization.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • conv (Any) – Flax module implementing the convolution layer to apply.

  • norm (Any) – Flax module implementing the batch normalization layer to apply.

  • act (Callable[..., Array]) – Flax function defining the activation operation to apply.

  • kernel_size (Tuple[int, int]) – A shape tuple defining the size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – A shape tuple defining the strides of the convolution. Default: (1, 1).

__call__(x)

Apply successive convolution, normalization and activation blocks.

Apply successive blocks, each one composed of convolution, normalization and activation.

Parameters:

x (Array) – The array to be transformed.

Return type:

Array

Returns:

The transformed input.

scico.flax.blocks.upscale_nn(x, scale=2)

Nearest neighbor upscale for image batches of shape (N, H, W, C).

Parameters:
  • x (Array) – Input tensor of shape (N, H, W, C).

  • scale (int) – Integer scaling factor, applied to both height and width. Default: 2.

Return type:

Array

Returns:

Output tensor of shape (N, H * scale, W * scale, C).
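Nearest-neighbor upscaling copies each pixel into a `scale` x `scale` patch. A minimal JAX sketch using `jnp.repeat` along the height and width axes follows; it is one way to compute the documented result, not necessarily how scico implements `upscale_nn`, and the name `upscale_nn_sketch` is hypothetical.

```python
import jax.numpy as jnp


def upscale_nn_sketch(x, scale=2):
    """Nearest-neighbor upscaling of an (N, H, W, C) batch by an integer factor."""
    # Repeat every row `scale` times (axis 1), then every column (axis 2).
    return jnp.repeat(jnp.repeat(x, scale, axis=1), scale, axis=2)


x = jnp.arange(4.0).reshape(1, 2, 2, 1)  # one 2x2 single-channel image
y = upscale_nn_sketch(x, scale=2)
print(y.shape)  # (1, 4, 4, 1)
# Each input pixel now fills a 2x2 patch of y[0, :, :, 0].
```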