scico.flax#

Neural network models implemented in Flax and utility functions.

Functions

count_parameters(params)

Return count of variables for the parameter dictionary.

create_input_iter(key, dataset, batch_size)

Create data iterator for training.

load_weights(filename)

Load trained model weights.

only_apply(config, model, test_ds[, ...])

Execute model application loop.

save_weights(variables, filename)

Save trained model weights.

Classes

BasicFlaxTrainer(config, model, train_ds, ...)

Class for encapsulating Flax training configuration and execution.

ConfigDict(*args, **kwargs)

Dictionary structure for training parameters.

ConvBNNet(depth, channels[, num_filters, ...])

Convolution and batch normalization net.

DnCNNNet(depth, channels[, num_filters, ...])

Flax implementation of DnCNN [58].

FlaxMap(model, variables)

A trained flax model.

MoDLNet(operator, depth, channels, ...[, ...])

Flax implementation of MoDL [1].

ODPNet(operator, depth, channels, ...[, ...])

Flax implementation of ODP network [19].

ResNet(depth, channels[, num_filters, ...])

Flax implementation of convolutional network with residual connection.

UNet(depth, channels[, num_filters, ...])

Flax implementation of U-Net model [43].

class scico.flax.FlaxMap(model, variables)#

Bases: object

A trained flax model.

Initialize a FlaxMap object.

Parameters:
  • model (Module) – Flax model to apply.

  • variables (Any) – Parameters and batch stats of trained model.

__call__(x)[source]#

Apply trained flax model.

Parameters:

x (Array) – Input array.

Return type:

Array

Returns:

Output of flax model.
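
For illustration, a minimal sketch that wraps a freshly initialized DnCNNNet (in practice the variables would come from training or from load_weights). It assumes, as in the SCICO example scripts, that a 2D input is accepted and singleton batch/channel axes are handled internally:

    import jax.numpy as jnp
    from jax import random
    from scico import flax as sflax

    # Build a small denoiser and initialize its variables.
    model = sflax.DnCNNNet(depth=3, channels=1, num_filters=16)
    key = random.PRNGKey(0)
    variables = model.init(key, jnp.ones((1, 32, 32, 1)))  # NHWC sample input

    fmap = sflax.FlaxMap(model, variables)
    x = jnp.zeros((32, 32))  # 2D image
    y = fmap(x)              # output of the same shape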

scico.flax.load_weights(filename)#

Load trained model weights.

Parameters:

filename (str) – Name of file containing parameters for trained model.

Return type:

Any

Returns:

A tree-like structure containing the values of the parameters of the model.

scico.flax.save_weights(variables, filename)#

Save trained model weights.

Parameters:
  • variables (Any) – Parameters of model to save.

  • filename (str) – Name of file to save parameters of trained model.
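
A typical round trip, continuing the FlaxMap sketch above (the file name is hypothetical):

    from scico import flax as sflax

    # Persist the variables of a trained model and restore them later.
    sflax.save_weights(variables, "dncnn_weights.mpk")  # hypothetical file name
    restored = sflax.load_weights("dncnn_weights.mpk")
    fmap = sflax.FlaxMap(model, restored)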

class scico.flax.ConvBNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Convolution and batch normalization net.

Net constructed from successive applications of convolution plus batch normalization blocks. No residual connection.

Parameters:
  • depth (int) – Depth of net.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the layers of the block. Corresponds to the number of channels in the network processing.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: 3x3.

  • strides (Tuple[int, int]) – Convolution strides. Default: 1x1.

  • dtype (Any) – Output dtype. Default: float32.

__call__(x, train=True)[source]#

Apply ConvBNNet.

Parameters:
  • x (Array) – The array to be transformed.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The ConvBNNet result.
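
The train flag follows the usual Flax batch normalization pattern; a minimal sketch with arbitrary shapes:

    import jax.numpy as jnp
    from jax import random
    from scico import flax as sflax

    model = sflax.ConvBNNet(depth=4, channels=1, num_filters=32)
    key = random.PRNGKey(0)
    x = jnp.ones((1, 32, 32, 1))  # NHWC input
    variables = model.init(key, x)

    # Training stage: batch statistics are updated, so mark them mutable.
    y, new_state = model.apply(variables, x, train=True, mutable=["batch_stats"])

    # Testing stage: running statistics are used and nothing is mutated.
    y = model.apply(variables, x, train=False)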

class scico.flax.DnCNNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, act=<scico.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Flax implementation of DnCNN [58].

Flax implementation of the convolutional neural network (CNN) architecture for denoising described in [58].

Attributes:
  • depth – Number of layers in the neural network.

  • channels – Number of channels of input tensor.

  • num_filters – Number of filters in the convolutional layers.

  • kernel_size – Size of the convolution filters. Default: (3, 3).

  • strides – Convolution strides. Default: (1, 1).

  • dtype – Output dtype. Default: float32.

  • act – Class of activation function to apply. Default: nn.relu.

__call__(inputs, train=True)[source]#

Apply DnCNN denoiser.

Parameters:
  • inputs (Array) – The array to be transformed.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The denoised input.

act: Callable = <scico.custom_jvp object>#

class scico.flax.ResNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Flax implementation of convolutional network with residual connection.

Net constructed from successive applications of convolution plus batch normalization blocks, ending with a residual connection (i.e. the input is added to the output of the block stack).

Parameters:
  • depth (int) – Depth of residual net.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the layers of the block. Corresponds to the number of channels in the network processing.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: 3x3.

  • strides (Tuple[int, int]) – Convolution strides. Default: 1x1.

  • dtype (Any) – Output dtype. Default: float32.

__call__(x, train=True)[source]#

Apply ResNet.

Parameters:
  • x (Array) – The array to be transformed.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The ResNet result.

class scico.flax.UNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), block_depth=2, window_shape=(2, 2), upsampling=2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Flax implementation of U-Net model [43].

Parameters:
  • depth (int) – Depth of U-Net.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the network processing.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: 3x3.

  • strides (Tuple[int, int]) – Convolution strides. Default: 1x1.

  • block_depth (int) – Number of processing layers per block. Default: 2.

  • window_shape (Tuple[int, int]) – Window used for the pooling/downsampling reduction. Default: 2x2.

  • upsampling (int) – Expansion (upsampling) factor. Default: 2.

  • dtype (Any) – Output dtype. Default: float32.

__call__(x, train=True)[source]#

Apply U-Net.

Parameters:
  • x (Array) – The array to be transformed.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The U-Net result.
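
A minimal construction sketch; the spatial dimensions of the input are assumed to be divisible by the cumulative downsampling factor implied by depth and window_shape (here a multiple of 4 is safe):

    import jax.numpy as jnp
    from jax import random
    from scico import flax as sflax

    model = sflax.UNet(depth=3, channels=1, num_filters=16)
    key = random.PRNGKey(0)
    x = jnp.ones((1, 64, 64, 1))  # spatial dims divisible by pooling factors
    variables = model.init(key, x)
    y = model.apply(variables, x, train=False)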

class scico.flax.MoDLNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), lmbda_ini=0.5, dtype=<class 'jax.numpy.float32'>, cg_iter=10, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Flax implementation of MoDL [1].

Flax implementation of the model-based deep learning (MoDL) architecture for inverse problems described in [1].

Parameters:
  • operator (Any) – Operator for computing forward and adjoint mappings.

  • depth (int) – Depth of MoDL net. Default: 1.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • block_depth (int) – Number of layers in the computational block.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).

  • lmbda_ini (float) – Initial value of the regularization weight lambda. Default: 0.5.

  • dtype (Any) – Output dtype. Default: float32.

  • cg_iter (int) – Number of iterations for cg solver. Default: 10.

__call__(y, train=True)[source]#

Apply MoDL net for inversion.

Parameters:
  • y (Array) – The array with signal to invert.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The reconstructed signal.
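
A construction sketch; the identity operator below is only a stand-in satisfying the forward/adjoint interface, where a real application would use a problem-specific map such as a CT projector:

    from scico import linop
    from scico import flax as sflax

    A = linop.Identity(input_shape=(32, 32))  # stand-in forward operator
    model = sflax.MoDLNet(
        operator=A,
        depth=2,
        channels=1,
        num_filters=16,
        block_depth=3,
        cg_iter=5,
    )
    # Training then proceeds via BasicFlaxTrainer, as for the other models.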

class scico.flax.ODPNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.5, dtype=<class 'jax.numpy.float32'>, odp_block=<class 'scico.flax.inverse.ODPProxDnBlock'>, parent=<flax.linen.module._Sentinel object>, name=None)#

Bases: Module

Flax implementation of ODP network [19].

Flax implementation of the unrolled optimization with deep priors (ODP) network for inverse problems described in [19]. It can be constructed with proximal gradient blocks or gradient descent blocks.

Parameters:
  • operator (Any) – Operator for computing forward and adjoint mappings.

  • depth (int) – Depth of ODP net. Default: 1.

  • channels (int) – Number of channels of input tensor.

  • num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.

  • block_depth (int) – Number of layers in the computational block.

  • kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).

  • strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).

  • alpha_ini (float) – Initial value of the fidelity weight alpha. Default: 0.5.

  • dtype (Any) – Output dtype. Default: float32.

  • odp_block (Callable) – Processing block to apply. Default: ODPProxDnBlock.

__call__(y, train=True)[source]#

Apply ODP net for inversion.

Parameters:
  • y (Array) – The array with signal to invert.

  • train (bool) – Flag to differentiate between training and testing stages.

Return type:

Array

Returns:

The reconstructed signal.

odp_block#

alias of ODPProxDnBlock
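
A construction sketch selecting the gradient descent variant via odp_block (the block name ODPGrDescBlock is assumed to be available from scico.flax.inverse; the operator is a stand-in as in the MoDLNet sketch):

    from scico import linop
    from scico import flax as sflax
    from scico.flax.inverse import ODPGrDescBlock  # assumed block name

    A = linop.Identity(input_shape=(32, 32))  # stand-in forward operator
    model = sflax.ODPNet(
        operator=A,
        depth=2,
        channels=1,
        num_filters=16,
        block_depth=3,
        odp_block=ODPGrDescBlock,
    )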

scico.flax.create_input_iter(key, dataset, batch_size, size_device_prefetch=2, dtype=<class 'jax.numpy.float32'>, train=True)#

Create data iterator for training.

Create data iterator for training by sharding and prefetching batches on device.

Parameters:
  • key (Array) – A PRNGKey used for random data permutations.

  • dataset (DataSetDict) – Dictionary of data for supervised training including images and labels.

  • batch_size (int) – Size of batch for iterating through the data.

  • size_device_prefetch (int) – Size of prefetch buffer. Default: 2.

  • dtype (Any) – Type of data to handle. Default: float32.

  • train (bool) – Flag indicating the type of iterator to construct and use. The iterator for training permutes data on each epoch while the iterator for testing passes through the data without permuting it. Default: True.

Return type:

Any

Returns:

An iterator over the provided dataset that yields batches of array-like data sharded to the available devices.
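
A minimal sketch; the dataset keys follow the DataSetDict convention used elsewhere on this page, and batch_size is assumed to be divisible by the number of local devices:

    import numpy as np
    from jax import random
    from scico import flax as sflax

    train_ds = {
        "image": np.zeros((128, 32, 32, 1), dtype=np.float32),
        "label": np.zeros((128, 32, 32, 1), dtype=np.float32),
    }
    key = random.PRNGKey(0)
    it = sflax.create_input_iter(key, train_ds, batch_size=16)
    batch = next(it)  # sharded batch with "image" and "label" entries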

class scico.flax.ConfigDict(*args, **kwargs)#

Bases: dict

Dictionary structure for training parameters.

Definition of the dictionary structure expected for specifying training parameters.
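
A sketch of a typical configuration, with keys as used in the SCICO example scripts (the exact set of recognized keys is defined in scico.flax.train):

    from scico import flax as sflax

    config: sflax.ConfigDict = {
        "seed": 0,
        "opt_type": "ADAM",
        "batch_size": 16,
        "num_epochs": 50,
        "base_learning_rate": 1e-3,
        "warmup_epochs": 0,
        "log_every_steps": 100,
        "log": True,
    }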

class scico.flax.BasicFlaxTrainer(config, model, train_ds, test_ds, variables0=None)#

Bases: object

Class for encapsulating Flax training configuration and execution.

Initializer for BasicFlaxTrainer.

Initializer for BasicFlaxTrainer to configure the model training and evaluation loop. Constructs a Flax train state (which includes the model apply function, the model parameters, and an Optax optimizer). Training is data-parallel, assuming sharded batched data.

Parameters:
  • config (ConfigDict) – Hyperparameter configuration.

  • model (Any) – Flax model to train.

  • train_ds (DataSetDict) – Dictionary of training data (includes images and labels).

  • test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).

  • variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. Default: None.

checkpoint(state)[source]#

Checkpoint training state if enabled.

Parameters:

state (TrainState) – Flax train state.

configure_reporting(config)[source]#

Configure logging and checkpointing.

The parameters configured correspond to

  • logflag: A flag for logging the evolution of results to the output terminal. Default: False.

  • workdir: Directory to write checkpoints. Default: execution directory.

  • checkpointing: A flag for checkpointing model state. Default: False.

  • return_state: A flag for returning the train state instead of the model variables. Default: False, i.e. return model variables.

Parameters:

config (ConfigDict) – Hyperparameter configuration.

configure_steps(config, len_train, len_test)[source]#

Configure training, evaluation and monitoring steps.

Parameters:
  • config (ConfigDict) – Hyperparameter configuration.

  • len_train (int) – Number of samples in training set.

  • len_test (int) – Number of samples in testing set.

configure_training_functions(config)[source]#

Construct training functions.

Default functions are used if not specified in configuration.

The parameters configured correspond to

  • lr_schedule: A function that creates an Optax learning rate schedule. Default: create_cnst_lr_schedule.

  • criterion: A function that specifies the loss being minimized in training. Default: mse_loss.

  • create_train_state: A function that creates and initializes a Flax train state. A train state object helps to keep optimizer and module functionality grouped for training. Default: create_basic_train_state.

  • train_step_fn: A function that executes a training step. Default: train_step, i.e. use the standard train step.

  • eval_step_fn: A function that executes an eval step. Default: eval_step, i.e. use the standard eval step.

  • metrics_fn: A function that computes metrics. Default: compute_metrics, i.e. use the standard compute metrics function.

  • post_lst: List of postprocessing functions to apply to the parameter set after each optimizer step (e.g. clip to a specified range, normalize, etc.).

Parameters:

config (ConfigDict) – Hyperparameter configuration.

construct_data_iterators(train_ds, test_ds, key, mdtype)[source]#

Construct iterators for training and testing (evaluation) sets.

Parameters:
  • train_ds (DataSetDict) – Dictionary of training data (includes images and labels).

  • test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).

  • key (Array) – A PRNGKey used as the random key.

  • mdtype (Any) – Output type of Flax model to be trained.

define_parallel_training_functions()[source]#

Construct parallel versions of training functions.

Construct parallel versions of training functions via jax.pmap.

initialize_training_state(config, key, model, variables0=None)[source]#

Construct and initialize Flax train state.

A train state object helps to keep optimizer and module functionality grouped for training.

Parameters:
  • config (ConfigDict) – Hyperparameter configuration.

  • key (Array) – A PRNGKey used as the random key.

  • model (Any) – Flax model to train.

  • variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. Default: None.

log(logstr)[source]#

Print stats to output terminal if logging is enabled.

Parameters:

logstr (str) – String to be logged.

set_training_parameters(config, len_train, len_test)[source]#

Extract configuration parameters and construct training functions.

Parameters and functions are passed in the configuration dictionary. Default values are used when parameters are not included in configuration.

Parameters:
  • config (ConfigDict) – Hyperparameter configuration.

  • len_train (int) – Number of samples in training set.

  • len_test (int) – Number of samples in testing set.

train()[source]#

Execute training loop.

Return type:

Tuple[Dict[str, Any], Optional[IterationStats]]

Returns:

Model variables extracted from the TrainState and an iteration stats object, obtained after executing the training loop. If return_state is enabled in the configuration, the TrainState is returned directly instead of the model variables. The iteration stats object is None unless logging is enabled when configuring the training loop.

update_metrics(state, step, train_metrics, t0)[source]#

Compute metrics for current model state.

Metrics for training and testing (eval) sets are computed and stored in an iteration stats object. This is executed only if logging is enabled.

Parameters:
  • state (TrainState) – Flax train state which includes the model apply function and the model parameters.

  • step (int) – Current step in training.

  • train_metrics (List[MetricsDict]) – List of diagnostic statistics computed from training set.

  • t0 – Time when training loop started.
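
Putting the trainer pieces together, a minimal end-to-end sketch, assuming config is a ConfigDict as sketched above and train_ds/test_ds are DataSetDict dictionaries as in the create_input_iter sketch:

    from scico import flax as sflax

    model = sflax.DnCNNNet(depth=3, channels=1, num_filters=16)
    trainer = sflax.BasicFlaxTrainer(config, model, train_ds, test_ds)
    modvar, stats = trainer.train()  # stats is None unless logging is enabled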

scico.flax.only_apply(config, model, test_ds, apply_fn=<function apply_fn>, variables=None)#

Execute model application loop.

Parameters:
  • config (ConfigDict) – Hyperparameter configuration.

  • model (Any) – Flax model to apply.

  • test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).

  • apply_fn (Callable) – A hook for a function that applies the current model. Default: apply_fn, i.e. use the standard apply function.

  • variables (Optional[ModelVarDict]) – Model parameters to use for evaluation. Default: None (i.e. read from checkpoint).

Return type:

Tuple[Array, ModelVarDict]

Returns:

Output of model evaluated at the input provided in test_ds.

Raises:

RuntimeError – If no model variables and no checkpoint are specified.
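
A minimal sketch, reusing config, model, test_ds, and the trained variables modvar from the training sketch above; passing variables explicitly avoids reading from a checkpoint:

    from scico import flax as sflax

    output, modvar = sflax.only_apply(config, model, test_ds, variables=modvar)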

scico.flax.count_parameters(params)#

Return count of variables for the parameter dictionary.

Parameters:

params (Any) – Flax model parameters.

Return type:

int

Returns:

The number of parameters in the model.
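
A minimal sketch, with variables as returned by model.init or by training:

    from scico import flax as sflax

    n = sflax.count_parameters(variables["params"])
    print(f"Model has {n} parameters")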

Modules

scico.flax.blocks

Flax implementation of different convolutional blocks.

scico.flax.examples

Data utility functions used by Flax example scripts.

scico.flax.inverse

Flax implementation of different imaging inversion models.

scico.flax.train

Utilities for training Flax models.