scico.flax.train.spectral
Utils for spectral normalization of convolutional layers in Flax models.
Functions
conv – Compute convolution between input and kernel.
estimate_spectral_norm – Estimate spectral norm of operator.
Compute spectral norm of operator.
Normalize parameters of convolutional layer by its spectral norm.
Classes
CNN – Evaluation of convolution operator via Flax convolutional layer.
- scico.flax.train.spectral.estimate_spectral_norm(f, input_shape, seed=0, n_steps=10, eps=1e-12)
Estimate spectral norm of operator.
This function estimates the spectral norm of an operator by estimating its leading singular vectors with the power iteration method. The transpose of the operator, which the power iteration requires, is obtained via nested autodiff in JAX.
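The power iteration described above can be sketched as follows, assuming `f` is linear and acts on arrays of shape `input_shape`. The helper name `power_iteration_spectral_norm` is hypothetical, and `jax.vjp` is used here to apply the transpose of `f`; this is one way to obtain the transpose through autodiff and not necessarily the construction used inside `estimate_spectral_norm`.

```python
import jax
import jax.numpy as jnp


def power_iteration_spectral_norm(f, input_shape, seed=0, n_steps=10, eps=1e-12):
    """Hypothetical sketch: estimate the spectral norm of a linear operator f."""
    key = jax.random.PRNGKey(seed)
    # Random unit-norm starting vector.
    u = jax.random.normal(key, input_shape)
    u = u / (jnp.linalg.norm(u) + eps)
    for _ in range(n_steps):
        # v = f(u); f_vjp applies the transpose of f (f is assumed linear).
        v, f_vjp = jax.vjp(f, u)
        (u,) = f_vjp(v)  # u <- A^T A u
        u = u / (jnp.linalg.norm(u) + eps)
    # For a unit vector u aligned with the top right singular vector,
    # ||f(u)|| approximates the largest singular value.
    return jnp.linalg.norm(f(u))


A = jnp.array([[2.0, 1.0], [1.0, 2.0]])  # singular values 3 and 1
est = power_iteration_spectral_norm(lambda x: A @ x, (2,), n_steps=50)
```

Each step applies A^T A, so the iterate converges to the top right singular vector at a rate governed by the ratio of the two largest squared singular values.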
- class scico.flax.train.spectral.CNN(kernel_size, kernel0, dtype, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Evaluation of convolution operator via Flax convolutional layer.
Evaluation of convolution operator via the Flax implementation of a convolutional layer. This form of convolution is used only for estimating the spectral norm of the operator. Therefore, the kernel values are also provided (attribute kernel0).
- Attributes:
kernel_size – Size of the convolution filter.
kernel0 – Convolution filter.
dtype – Output type.
- scico.flax.train.spectral.conv(inputs, kernel)
Compute convolution between input and kernel.
The convolution is evaluated via the CNN Flax model.