scico.flax.train.spectral

Utils for spectral normalization of convolutional layers in Flax models.

Functions

conv(inputs, kernel)

Compute convolution between input and kernel.

estimate_spectral_norm(f, input_shape[, ...])

Estimate spectral norm of operator.

exact_spectral_norm(f, input_shape)

Compute spectral norm of operator.

spectral_normalization_conv(params, ...[, ...])

Normalize parameters of convolutional layer by its spectral norm.

Classes

CNN(kernel_size, kernel0, dtype[, parent, name])

Evaluation of convolution operator via Flax convolutional layer.

scico.flax.train.spectral.estimate_spectral_norm(f, input_shape, seed=0, n_steps=10, eps=1e-12)

Estimate spectral norm of operator.

This function estimates the spectral norm of an operator by approximating its leading singular vectors with the power iteration method, using the transpose of the operator obtained via nested autodiff in JAX.

Parameters:
  • f (Callable) – Operator whose spectral norm is to be estimated.

  • input_shape (Tuple[int, ...]) – Shape of input to operator.

  • seed (float) – Seed for the random number generation. Default: 0.

  • n_steps (int) – Number of power iterations to compute. Default: 10.

  • eps (float) – Small value to prevent division by zero. Default: 1e-12.

Returns:

Spectral norm.
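The power iteration described above can be sketched in plain JAX as follows. This is an illustrative version, not scico's implementation: the helper name power_iteration_norm is hypothetical, and the transpose of f is taken from jax.vjp, which for a linear f returns the action of its transpose (the documented function obtains the transpose via nested autodiff).

```python
import jax
import jax.numpy as jnp


def power_iteration_norm(f, input_shape, seed=0, n_steps=10, eps=1e-12):
    """Hypothetical sketch: estimate the spectral norm of a linear operator f."""
    key = jax.random.PRNGKey(seed)
    u = jax.random.normal(key, input_shape)  # random starting vector
    # For a linear f, the VJP at any point applies the transpose f^T.
    _, f_t = jax.vjp(f, jnp.zeros(input_shape))
    for _ in range(n_steps):
        v = f(u)
        v = v / (jnp.linalg.norm(v) + eps)  # normalized left singular vector estimate
        u = f_t(v)[0]  # vjp returns a tuple of cotangents
        u = u / (jnp.linalg.norm(u) + eps)  # normalized right singular vector estimate
    # With u approximately the top right singular vector, ||f(u)|| approximates sigma_max.
    return jnp.linalg.norm(f(u))
```

If this function is traced under jax.jit, the Python loop is unrolled; jax.lax.fori_loop is the usual alternative when n_steps is large.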

class scico.flax.train.spectral.CNN(kernel_size, kernel0, dtype, parent=<flax.linen.module._Sentinel object>, name=None)

Bases: Module

Evaluation of convolution operator via Flax convolutional layer.


Evaluation of a convolution operator via the Flax implementation of a convolutional layer. This form of convolution is used only for estimating the spectral norm of the operator, so the kernel itself is provided as an attribute (kernel0).

Attributes:
  • kernel_size – Size of the convolution filter.

  • kernel0 – Convolution filter.

  • dtype – Output type.

__call__(x)

Apply CNN layer.

Parameters:

x – The array to be convolved.

Returns:

The result of the convolution with kernel0.

scico.flax.train.spectral.conv(inputs, kernel)

Compute convolution between input and kernel.

The convolution is evaluated via a CNN Flax model.

Parameters:
  • inputs (Array) – Array to be convolved.

  • kernel (Array) – Filter of the convolutional operator.

Return type:

Array

Returns:

Result of convolution of input with kernel.

scico.flax.train.spectral.spectral_normalization_conv(params, traversal, xshape, n_steps=10)

Normalize parameters of convolutional layer by its spectral norm.

Parameters:
  • params (Any) – Current model parameters.

  • traversal (ModelParamTraversal) – Utility to select model parameters.

  • xshape (Tuple[int, ...]) – Shape of input.

  • n_steps (int) – Number of power iterations to compute. Default: 10.

Return type:

Any

scico.flax.train.spectral.exact_spectral_norm(f, input_shape)

Compute spectral norm of operator.

This function computes the spectral norm of an operator via autodiff in JAX.

Parameters:
  • f – Operator whose spectral norm is to be computed.

  • input_shape – Shape of input to operator.

Returns:

Spectral norm.
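This page does not say how autodiff is used here. One straightforward way, assumed for this sketch, is to materialize the operator's Jacobian with jax.jacfwd, reshape it into a matrix, and take its largest singular value. The helper name dense_spectral_norm is hypothetical, and the approach is only practical for small input sizes because the matrix has one row per output element and one column per input element.

```python
import math

import jax
import jax.numpy as jnp


def dense_spectral_norm(f, input_shape):
    """Hypothetical sketch: exact spectral norm of a linear f via its dense Jacobian."""
    x0 = jnp.zeros(input_shape)
    jac = jax.jacfwd(f)(x0)  # shape: output_shape + input_shape
    n_in = math.prod(input_shape)
    mat = jac.reshape(-1, n_in)  # (n_out, n_in) matrix representing f
    return jnp.linalg.norm(mat, 2)  # largest singular value
```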