scico.flax
Neural network models implemented in Flax and utility functions.
Many of the function and parameter names used in this sub-package are based on the somewhat non-standard Flax terminology for neural network components:
- model
The model is an abstract representation of the network structure that does not include specific weight values.
- parameters
The parameters of a model are the weights of the network represented by the model.
- variables
The variables encompass both the parameters (i.e. network weights) and secondary values that are set from training data, such as layer-dependent statistics used in batch normalization.
- state
The state encompasses both a set of model parameters as well as optimizer parameters involved in training of that model. Storing the state rather than just the variables enables a warm start for additional training.
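As a minimal sketch of this terminology using flax.linen directly (the Dense module and shapes here are arbitrary illustrations, not part of scico.flax):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    # Model: abstract network structure with no weight values attached.
    model = nn.Dense(features=8)

    # Variables: produced by initializing the model with a random key and a
    # sample input; they include the parameters (network weights) and, for
    # models with e.g. batch normalization, secondary collections as well.
    key = jax.random.PRNGKey(0)
    x = jnp.ones((1, 4))
    variables = model.init(key, x)
    params = variables["params"]  # the weights themselves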
Functions

count_parameters – Return count of variables for the parameter dictionary.
create_input_iter – Create data iterator for training.
load_variables – Load trained model variables.
only_apply – Execute model application loop.
save_variables – Save trained model weights.
Classes

BasicFlaxTrainer – Class encapsulating Flax training configuration and execution.
ConfigDict – Dictionary structure for training parameters.
ConvBNNet – Convolution and batch normalization net.
DnCNNNet – Flax implementation of DnCNN [59].
FlaxMap – A trained flax model.
MoDLNet – Flax implementation of MoDL [1].
ODPNet – Flax implementation of ODP network [19].
ResNet – Flax implementation of convolutional network with residual connection.
UNet – Flax implementation of U-Net model [44].
- class scico.flax.FlaxMap(model, variables)
Bases: object
A trained flax model.
Initialize a FlaxMap object.
- Parameters:
  model – Flax model to apply.
  variables – Variables (parameters and, if applicable, batch stats) of the trained model.
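A sketch of typical usage, assuming a DnCNNNet denoiser with variables obtained from a previous training run (model sizes and the input image are hypothetical):

    import jax.numpy as jnp
    from scico import flax as sflax

    # Model structure matching the one used in training (hypothetical sizes);
    # `variables` is assumed to hold the corresponding trained variables.
    model = sflax.DnCNNNet(depth=6, channels=1, num_filters=64)
    denoiser = sflax.FlaxMap(model, variables)

    # Apply the trained model to a noisy image (placeholder array here).
    noisy = jnp.zeros((128, 128))
    denoised = denoiser(noisy)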
- scico.flax.load_variables(filename)
Load trained model variables.
- scico.flax.save_variables(variables, filename)
Save trained model weights.
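For example, a save/load round trip after training might look like the following (the file path is hypothetical):

    from scico import flax as sflax

    # Persist trained variables and restore them later.
    sflax.save_variables(variables, "model_variables.mpk")
    restored = sflax.load_variables("model_variables.mpk")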
- class scico.flax.ConvBNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Convolution and batch normalization net.
Net constructed from successive applications of convolution plus batch normalization blocks. No residual connection.
- class scico.flax.DnCNNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, act=<scico.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of DnCNN [59].
Flax implementation of the convolutional neural network (CNN) architecture for denoising described in [59].
- Attributes:
  depth – Number of layers in the neural network.
  channels – Number of channels of input tensor.
  num_filters – Number of filters in the convolutional layers.
  kernel_size – Size of the convolution filters.
  strides – Convolution strides.
  dtype – Output dtype. Default: float32.
  act – Class of activation function to apply. Default: relu.
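A sketch of constructing the model and initializing its variables from a sample input (sizes are arbitrary; a batch, height, width, channels layout is assumed):

    import jax
    import jax.numpy as jnp
    from scico import flax as sflax

    # Six-layer DnCNN for single-channel images.
    model = sflax.DnCNNNet(depth=6, channels=1, num_filters=64)

    # Initialize variables (weights and batch stats) from a random key
    # and a sample input.
    key = jax.random.PRNGKey(0)
    x = jnp.zeros((1, 40, 40, 1))
    variables = model.init(key, x)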
- class scico.flax.ResNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of convolutional network with residual connection.
Net constructed from successive applications of convolution plus batch normalization blocks, ending with a residual connection (i.e. adding the input to the output of the block).
- class scico.flax.UNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), block_depth=2, window_shape=(2, 2), upsampling=2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of U-Net model [44].
- Parameters:
  depth (int) – Depth of U-Net.
  channels (int) – Number of channels of input tensor.
  num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the network processing.
  kernel_size (Tuple[int, int]) – Size of the convolution filters.
  block_depth (int) – Number of processing layers per block.
  window_shape (Tuple[int, int]) – Window for reduction for pooling and downsampling.
  upsampling (int) – Factor for expanding.
- class scico.flax.MoDLNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), lmbda_ini=0.5, dtype=<class 'jax.numpy.float32'>, cg_iter=10, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of MoDL [1].
Flax implementation of the model-based deep learning (MoDL) architecture for inverse problems described in [1].
- Parameters:
  operator (Any) – Operator for computing forward and adjoint mappings.
  depth (int) – Depth of MoDL net.
  channels (int) – Number of channels of input tensor.
  num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  block_depth (int) – Number of layers in the computational block.
  kernel_size (Tuple[int, int]) – Size of the convolution filters.
  lmbda_ini (float) – Initial value of the regularization weight lambda.
  cg_iter (int) – Number of iterations of the conjugate gradient (CG) solver.
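A construction sketch, using an identity operator purely as a stand-in for a real forward operator such as a blur or tomographic projection (all sizes hypothetical):

    from scico import flax as sflax
    from scico.linop import Identity

    # Stand-in forward operator; a real application would use e.g. a
    # convolution or Radon transform LinearOperator.
    A = Identity(input_shape=(64, 64))

    model = sflax.MoDLNet(
        operator=A,
        depth=2,          # number of unrolled iterations
        channels=1,
        num_filters=64,
        block_depth=4,    # layers in each CNN block
    )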
- class scico.flax.ODPNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.5, dtype=<class 'jax.numpy.float32'>, odp_block=<class 'scico.flax.inverse.ODPProxDnBlock'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of ODP network [19].
Flax implementation of the unrolled optimization with deep priors (ODP) network for inverse problems described in [19]. It can be constructed with proximal gradient blocks or gradient descent blocks.
- Parameters:
  operator (Any) – Operator for computing forward and adjoint mappings.
  depth (int) – Depth of ODP net.
  channels (int) – Number of channels of input tensor.
  num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
  block_depth (int) – Number of layers in the computational block.
  kernel_size (Tuple[int, int]) – Size of the convolution filters.
  alpha_ini (float) – Initial value of the fidelity weight alpha.
  odp_block (Callable) – Processing block to apply. Default: ODPProxDnBlock.
- odp_block
  alias of ODPProxDnBlock
- scico.flax.create_input_iter(key, dataset, batch_size, size_device_prefetch=2, dtype=<class 'jax.numpy.float32'>, train=True)
Create data iterator for training by sharding and prefetching batches on device.
- Parameters:
  key (Array) – A PRNGKey used for random data permutations.
  dataset (DataSetDict) – Dictionary of data for supervised training including images and labels.
  batch_size (int) – Size of batch for iterating through the data.
  size_device_prefetch (int) – Size of prefetch buffer. Default: 2.
  train (bool) – Flag indicating the type of iterator to construct and use. The iterator for training permutes data on each epoch while the iterator for testing passes through the data without permuting it. Default: True.
- Returns:
  Array-like data sharded to specific devices coming from an iterator built from the provided dataset.
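A usage sketch, assuming train_ds is a DataSetDict prepared beforehand (the dataset itself is hypothetical):

    import jax
    from scico import flax as sflax

    # Build a sharded, prefetching iterator over the training data.
    key = jax.random.PRNGKey(0)
    train_iter = sflax.create_input_iter(key, train_ds, batch_size=16)

    # Each element is a batch sharded across the available devices.
    batch = next(train_iter)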
- class scico.flax.ConfigDict
Bases: TypedDict
Dictionary structure for training parameters.
Definition of the dictionary structure expected for specifying training parameters.
- class scico.flax.BasicFlaxTrainer(config, model, train_ds, test_ds, variables0=None)
Bases: object
Class encapsulating Flax training configuration and execution.
Initializer for BasicFlaxTrainer to configure model training and evaluation loop. Constructs a Flax train state (which includes the model apply function, the model parameters and an Optax optimizer). This uses data parallel training assuming sharded batched data.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
  model (Any) – Flax model to train.
  train_ds (DataSetDict) – Dictionary of training data (includes images and labels).
  test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
  variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters.
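A training sketch; the configuration keys shown are illustrative assumptions (see ConfigDict for the expected structure), and train_ds/test_ds are hypothetical DataSetDict dictionaries prepared beforehand:

    from scico import flax as sflax

    # Illustrative hyperparameter configuration (keys assumed).
    config = {
        "seed": 0,
        "opt_type": "ADAM",
        "batch_size": 16,
        "num_epochs": 10,
        "base_learning_rate": 1e-3,
    }

    model = sflax.DnCNNNet(depth=6, channels=1, num_filters=64)
    trainer = sflax.BasicFlaxTrainer(config, model, train_ds, test_ds)

    # Returns trained model variables and, if logging is enabled, an
    # iteration stats object (otherwise None).
    variables, stats = trainer.train()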
- checkpoint(state)
Checkpoint training state if enabled.
- Parameters:
  state (TrainState) – Flax train state.
- configure_reporting(config)
Configure logging and checkpointing.
The parameters configured correspond to:
- logflag: A flag for logging the evolution of results to the output terminal. Default: False.
- workdir: Directory to write checkpoints. Default: execution directory.
- checkpointing: A flag for checkpointing model state. Default: False.
- return_state: A flag for returning the train state instead of the model variables. Default: False, i.e. return model variables.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
- configure_steps(config, len_train, len_test)
Configure training, evaluation and monitoring steps.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
  len_train (int) – Number of samples in training set.
  len_test (int) – Number of samples in testing set.
- configure_training_functions(config)
Construct training functions.
Default functions are used if not specified in configuration. The parameters configured correspond to:
- lr_schedule: A function that creates an Optax learning rate schedule. Default: create_cnst_lr_schedule.
- criterion: A function that specifies the loss being minimized in training. Default: mse_loss.
- create_train_state: A function that creates a Flax train state and initializes it. A train state object helps to keep optimizer and module functionality grouped for training. Default: create_basic_train_state.
- train_step_fn: A function that executes a training step. Default: train_step, i.e. use the standard train step.
- eval_step_fn: A function that executes an eval step. Default: eval_step, i.e. use the standard eval step.
- metrics_fn: A function that computes metrics. Default: compute_metrics, i.e. use the standard compute metrics function.
- post_lst: List of postprocessing functions to apply to the parameter set after each optimizer step (e.g. clip to a specified range, normalize, etc.).
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
- construct_data_iterators(train_ds, test_ds, key, mdtype)
Construct iterators for training and testing (evaluation) sets.
- Parameters:
  train_ds (DataSetDict) – Dictionary of training data (includes images and labels).
  test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
  key (Array) – A PRNGKey used as the random key.
  mdtype (Any) – Output type of Flax model to be trained.
- define_parallel_training_functions()
Construct parallel versions of training functions via jax.pmap.
- initialize_training_state(config, key, model, variables0=None)
Construct and initialize Flax train state.
A train state object helps to keep optimizer and module functionality grouped for training.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
  key (Array) – A PRNGKey used as the random key.
  model (Any) – Flax model to train.
  variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters.
- log(logstr)
Print stats to output terminal if logging is enabled.
- Parameters:
  logstr (str) – String to be logged.
- set_training_parameters(config, len_train, len_test)
Extract configuration parameters and construct training functions.
Parameters and functions are passed in the configuration dictionary. Default values are used when parameters are not included in configuration.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
  len_train (int) – Number of samples in training set.
  len_test (int) – Number of samples in testing set.
- train()
Execute training loop.
- Returns:
  Model variables extracted from TrainState and an iteration stats object obtained after executing the training loop. Alternatively, the TrainState can be returned directly instead of the model variables. Note that the iteration stats object is not None only if logging is enabled when configuring the training loop.
- update_metrics(state, step, train_metrics, t0)
Compute metrics for current model state.
Metrics for training and testing (eval) sets are computed and stored in an iteration stats object. This is executed only if logging is enabled.
- Parameters:
  state (TrainState) – Flax train state which includes the model apply function and the model parameters.
  step (int) – Current step in training.
  train_metrics (List[MetricsDict]) – List of diagnostic statistics computed from training set.
  t0 – Time when training loop started.
- scico.flax.only_apply(config, model, test_ds, apply_fn=<function apply_fn>, variables=None)
Execute model application loop.
- Parameters:
  config (ConfigDict) – Hyperparameter configuration.
  model (Any) – Flax model to apply.
  test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
  apply_fn (Callable) – A hook for a function that applies current model. Default: apply_fn, i.e. use the standard apply function.
  variables (Optional[ModelVarDict]) – Model parameters to use for evaluation. Default: None (i.e. read from checkpoint).
- Returns:
  Output of model evaluated at the input provided in test_ds.
- Raises:
  RuntimeError – If no model variables and no checkpoint are specified.
- scico.flax.count_parameters(params)
Return count of variables for the parameter dictionary.
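For example, applied to the params component of variables initialized as in the sketches above:

    from scico import flax as sflax

    # Total number of trainable weights in the model.
    n_params = sflax.count_parameters(variables["params"])
    print(f"trainable parameters: {n_params}")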
Modules

scico.flax.blocks – Flax implementation of different convolutional blocks.
scico.flax.examples – Data utility functions used by Flax example scripts.
scico.flax.inverse – Flax implementation of different imaging inversion models.
scico.flax.train – Utilities for training Flax models.