scico.flax
Neural network models implemented in Flax and utility functions.
Functions
count_parameters | Return count of variables for the parameter dictionary.
create_input_iter | Create data iterator for training.
load_weights | Load trained model weights.
only_apply | Execute model application loop.
save_weights | Save trained model weights.
Classes
BasicFlaxTrainer | Class for encapsulating Flax training configuration and execution.
ConfigDict | Dictionary structure for training parameters.
ConvBNNet | Convolution and batch normalization net.
DnCNNNet | Flax implementation of DnCNN [58].
FlaxMap | A trained Flax model.
MoDLNet | Flax implementation of MoDL [1].
ODPNet | Flax implementation of ODP network [19].
ResNet | Flax implementation of convolutional network with residual connection.
UNet | Flax implementation of U-Net model [43].
- class scico.flax.FlaxMap(model, variables)
Bases: object
A trained Flax model.
Initialize a FlaxMap object.
- scico.flax.load_weights(filename)
Load trained model weights.
- scico.flax.save_weights(variables, filename)
Save trained model weights.
- class scico.flax.ConvBNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Convolution and batch normalization net.
Net constructed from successive applications of convolution plus batch normalization blocks. No residual connection.
- Parameters:
depth (int) – Depth of net.
channels (int) – Number of channels of input tensor.
num_filters (int) – Number of filters in the layers of the block. Corresponds to the number of channels in the network processing.
kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).
strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).
- class scico.flax.DnCNNNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, act=<scico.custom_jvp object>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of DnCNN [58].
Flax implementation of the convolutional neural network (CNN) architecture for denoising described in [58].
- Attributes:
depth – Number of layers in the neural network.
channels – Number of channels of input tensor.
num_filters – Number of filters in the convolutional layers.
kernel_size – Size of the convolution filters. Default: (3, 3).
strides – Convolution strides. Default: (1, 1).
dtype – Output dtype. Default: float32.
act – Class of activation function to apply. Default: nn.relu.
- class scico.flax.ResNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of convolutional network with residual connection.
Net constructed from successive applications of convolution plus batch normalization blocks and ending with a residual connection (i.e. adding the input to the output of the block).
- Parameters:
depth (int) – Depth of residual net.
channels (int) – Number of channels of input tensor.
num_filters (int) – Number of filters in the layers of the block. Corresponds to the number of channels in the network processing.
kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).
strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).
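The residual connection described above adds the block input to the block output. The following is a minimal sketch of that idea in plain Python, independent of Flax. The names `residual_block` and `toy_layers` are hypothetical and are not part of scico; the toy layers merely stand in for the convolution plus batch normalization blocks.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def residual_block(layers: Sequence[Callable[[Vector], Vector]], x: Vector) -> Vector:
    """Apply `layers` in sequence, then add the original input to the result."""
    y = x
    for layer in layers:
        y = layer(y)
    # Residual connection: output = input + f(input).
    return [xi + yi for xi, yi in zip(x, y)]


# Toy stand-ins (hypothetical) for the convolution plus batch normalization blocks.
toy_layers = [lambda v: [2.0 * vi for vi in v], lambda v: [vi - 1.0 for vi in v]]
out = residual_block(toy_layers, [1.0, 2.0, 3.0])
```

Because the input is added elementwise, the stacked layers must produce an output with the same shape as the input.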
- class scico.flax.UNet(depth, channels, num_filters=64, kernel_size=(3, 3), strides=(1, 1), block_depth=2, window_shape=(2, 2), upsampling=2, dtype=<class 'jax.numpy.float32'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of U-Net model [43].
- Parameters:
depth (int) – Depth of U-Net.
channels (int) – Number of channels of input tensor.
num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the network processing.
kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).
strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).
block_depth (int) – Number of processing layers per block. Default: 2.
window_shape (Tuple[int, int]) – Window for reduction for pooling and downsampling. Default: (2, 2).
upsampling (int) – Factor for expanding. Default: 2.
- class scico.flax.MoDLNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), lmbda_ini=0.5, dtype=<class 'jax.numpy.float32'>, cg_iter=10, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of MoDL [1].
Flax implementation of the model-based deep learning (MoDL) architecture for inverse problems described in [1].
- Parameters:
operator (Any) – Operator for computing forward and adjoint mappings.
depth (int) – Depth of MoDL net. Default: 1.
channels (int) – Number of channels of input tensor.
num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
block_depth (int) – Number of layers in the computational block.
kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).
strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).
lmbda_ini (float) – Initial value of the regularization weight lambda. Default: 0.5.
cg_iter (int) – Number of iterations for the conjugate gradient (CG) solver. Default: 10.
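To illustrate the roles of lambda and the CG iteration count, the sketch below solves the data consistency system used in the MoDL formulation, (A^T A + lambda I) x = A^T y + lambda z, with a fixed number of conjugate gradient iterations. It assumes the scico implementation follows the formulation in [1]; this page does not confirm that. It is plain Python on a toy diagonal operator, and all names (`cg`, `normal_op`, `a`, `y`, `z`) are hypothetical and not part of the scico API.

```python
from typing import Callable, List

Vector = List[float]


def dot(u: Vector, v: Vector) -> float:
    return sum(ui * vi for ui, vi in zip(u, v))


def cg(apply_op: Callable[[Vector], Vector], b: Vector, num_iter: int) -> Vector:
    """Run at most `num_iter` CG iterations for apply_op(x) = b, starting at x = 0."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - apply_op(x) for x = 0
    p = list(r)  # initial search direction
    rs = dot(r, r)
    for _ in range(num_iter):
        if rs < 1e-20:  # residual is negligible, stop early
            break
        q = apply_op(p)
        step = rs / dot(p, q)
        x = [xi + step * pi for xi, pi in zip(x, p)]
        r = [ri - step * qi for ri, qi in zip(r, q)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


# Toy diagonal forward operator A = diag(a), so that A^T A = diag(a^2).
a = [1.0, 2.0]
lmbda = 0.5  # same value as the documented default of lmbda_ini
y = [1.0, 4.0]  # toy measurements
z = [0.0, 2.0]  # stand-in for the output of the learned denoiser


def normal_op(v: Vector) -> Vector:
    # (A^T A + lmbda * I) v
    return [(ai * ai + lmbda) * vi for ai, vi in zip(a, v)]


rhs = [ai * yi + lmbda * zi for ai, yi, zi in zip(a, y, z)]  # A^T y + lmbda * z
x = cg(normal_op, rhs, num_iter=10)
```

A larger lambda pulls the solution toward the denoiser output `z`, while a smaller one favors agreement with the measurements `y`.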
- class scico.flax.ODPNet(operator, depth, channels, num_filters, block_depth, kernel_size=(3, 3), strides=(1, 1), alpha_ini=0.5, dtype=<class 'jax.numpy.float32'>, odp_block=<class 'scico.flax.inverse.ODPProxDnBlock'>, parent=<flax.linen.module._Sentinel object>, name=None)
Bases: Module
Flax implementation of ODP network [19].
Flax implementation of the unrolled optimization with deep priors (ODP) network for inverse problems described in [19]. It can be constructed with proximal gradient blocks or gradient descent blocks.
- Parameters:
operator (Any) – Operator for computing forward and adjoint mappings.
depth (int) – Depth of ODP net. Default: 1.
channels (int) – Number of channels of input tensor.
num_filters (int) – Number of filters in the convolutional layer of the block. Corresponds to the number of channels in the output tensor.
block_depth (int) – Number of layers in the computational block.
kernel_size (Tuple[int, int]) – Size of the convolution filters. Default: (3, 3).
strides (Tuple[int, int]) – Convolution strides. Default: (1, 1).
alpha_ini (float) – Initial value of the fidelity weight alpha. Default: 0.5.
odp_block (Callable) – Processing block to apply. Default: ODPProxDnBlock.
- odp_block
alias of ODPProxDnBlock
- scico.flax.create_input_iter(key, dataset, batch_size, size_device_prefetch=2, dtype=<class 'jax.numpy.float32'>, train=True)
Create data iterator for training.
Create data iterator for training by sharding and prefetching batches on device.
- Parameters:
key (Array) – A PRNGKey used for random data permutations.
dataset (DataSetDict) – Dictionary of data for supervised training including images and labels.
batch_size (int) – Size of batch for iterating through the data.
size_device_prefetch (int) – Size of prefetch buffer. Default: 2.
train (bool) – Flag indicating the type of iterator to construct and use. The iterator for training permutes data on each epoch while the iterator for testing passes through the data without permuting it. Default: True.
- Returns:
Array-like data sharded to specific devices coming from an iterator built from the provided dataset.
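The difference between the training and testing iterators can be sketched in plain Python. This sketch only generates index batches and does no sharding or device prefetching. The name `batch_indices` is hypothetical, and dropping an incomplete final batch is an assumption of the sketch rather than documented behavior.

```python
import random
from typing import Iterator, List


def batch_indices(
    num_samples: int, batch_size: int, seed: int, train: bool
) -> Iterator[List[int]]:
    """Yield index batches indefinitely, making one pass over the data per epoch."""
    rng = random.Random(seed)
    steps_per_epoch = num_samples // batch_size  # assumption: drop incomplete batch
    while True:
        order = list(range(num_samples))
        if train:
            rng.shuffle(order)  # a new permutation on each epoch
        for step in range(steps_per_epoch):
            yield order[step * batch_size : (step + 1) * batch_size]


# Testing iterator: passes through the data in order.
test_iter = batch_indices(num_samples=6, batch_size=2, seed=0, train=False)
first_test_epoch = [next(test_iter) for _ in range(3)]

# Training iterator: same samples, permuted order.
train_iter = batch_indices(num_samples=6, batch_size=2, seed=0, train=True)
first_train_epoch = [next(train_iter) for _ in range(3)]
```

Both iterators visit every sample once per epoch; only the order differs.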
- class scico.flax.ConfigDict(*args, **kwargs)
Bases: dict
Dictionary structure for training parameters.
Definition of the dictionary structure expected for specifying training parameters.
- class scico.flax.BasicFlaxTrainer(config, model, train_ds, test_ds, variables0=None)
Bases: object
Class for encapsulating Flax training configuration and execution.
Initializer for BasicFlaxTrainer.
Initializer for BasicFlaxTrainer to configure the model training and evaluation loop. Constructs a Flax train state (which includes the model apply function, the model parameters, and an Optax optimizer). This uses data-parallel training, assuming sharded batched data.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
model (Any) – Flax model to train.
train_ds (DataSetDict) – Dictionary of training data (includes images and labels).
test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. Default: None.
- checkpoint(state)
Checkpoint training state if enabled.
- Parameters:
state (TrainState) – Flax train state.
- configure_reporting(config)
Configure logging and checkpointing.
The parameters configured correspond to:
- logflag: A flag for logging the evolution of results to the output terminal. Default: False.
- workdir: Directory to write checkpoints. Default: execution directory.
- checkpointing: A flag for checkpointing model state. Default: False.
- return_state: A flag for returning the train state instead of the model variables. Default: False, i.e. return model variables.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
- configure_steps(config, len_train, len_test)
Configure training, evaluation and monitoring steps.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
len_train (int) – Number of samples in training set.
len_test (int) – Number of samples in testing set.
- configure_training_functions(config)
Construct training functions.
Default functions are used if not specified in configuration.
The parameters configured correspond to:
- lr_schedule: A function that creates an Optax learning rate schedule. Default: create_cnst_lr_schedule.
- criterion: A function that specifies the loss being minimized in training. Default: mse_loss.
- create_train_state: A function that creates a Flax train state and initializes it. A train state object helps to keep optimizer and module functionality grouped for training. Default: create_basic_train_state.
- train_step_fn: A function that executes a training step. Default: train_step, i.e. use the standard train step.
- eval_step_fn: A function that executes an eval step. Default: eval_step, i.e. use the standard eval step.
- metrics_fn: A function that computes metrics. Default: compute_metrics, i.e. use the standard compute metrics function.
- post_lst: List of postprocessing functions to apply to the parameter set after the optimizer step (e.g. clip to a specified range, normalize, etc.).
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
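The fallback to default functions, and the use of a postprocessing list, can be sketched as follows. This is a plain-Python illustration of the pattern only, not the scico implementation. The key names `criterion` and `post_lst` come from the list above, while `resolve_training_functions`, `clip_unit`, and the stand-in `mse_loss` are hypothetical.

```python
from typing import Any, Callable, Dict, List, Sequence


def mse_loss(output: Sequence[float], labels: Sequence[float]) -> float:
    """Stand-in mean squared error (hypothetical, not the scico function)."""
    return sum((o - t) ** 2 for o, t in zip(output, labels)) / len(labels)


def resolve_training_functions(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return the configured hooks, falling back to defaults for missing keys."""
    criterion: Callable = config.get("criterion", mse_loss)
    post_lst: List[Callable] = config.get("post_lst", [])
    return {"criterion": criterion, "post_lst": post_lst}


def clip_unit(params: List[float]) -> List[float]:
    """Example postprocessing function: clip each parameter to [0, 1]."""
    return [min(max(p, 0.0), 1.0) for p in params]


defaults = resolve_training_functions({})  # nothing specified: defaults apply
custom = resolve_training_functions({"post_lst": [clip_unit]})

params = [-0.5, 0.25, 1.5]
for post_fn in custom["post_lst"]:
    params = post_fn(params)  # in a real loop this follows the optimizer step
```

Keeping the hooks in the configuration dictionary lets a custom loss or step function be swapped in without subclassing the trainer.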
- construct_data_iterators(train_ds, test_ds, key, mdtype)
Construct iterators for training and testing (evaluation) sets.
- Parameters:
train_ds (DataSetDict) – Dictionary of training data (includes images and labels).
test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
key (Array) – A PRNGKey used as the random key.
mdtype (Any) – Output type of Flax model to be trained.
- define_parallel_training_functions()
Construct parallel versions of training functions.
Construct parallel versions of training functions via jax.pmap.
- initialize_training_state(config, key, model, variables0=None)
Construct and initialize Flax train state.
A train state object helps to keep optimizer and module functionality grouped for training.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
key (Array) – A PRNGKey used as the random key.
model (Any) – Flax model to train.
variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. Default: None.
- log(logstr)
Print stats to output terminal if logging is enabled.
- Parameters:
logstr (str) – String to be logged.
- set_training_parameters(config, len_train, len_test)
Extract configuration parameters and construct training functions.
Parameters and functions are passed in the configuration dictionary. Default values are used when parameters are not included in the configuration.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
len_train (int) – Number of samples in training set.
len_test (int) – Number of samples in testing set.
- train()
Execute training loop.
- Returns:
Model variables extracted from the TrainState, together with the iteration stats object obtained after executing the training loop. Alternatively, the TrainState can be returned directly instead of the model variables. Note that the iteration stats object is not None only if logging is enabled when configuring the training loop.
- update_metrics(state, step, train_metrics, t0)
Compute metrics for current model state.
Metrics for training and testing (eval) sets are computed and stored in an iteration stats object. This is executed only if logging is enabled.
- Parameters:
state (TrainState) – Flax train state which includes the model apply function and the model parameters.
step (int) – Current step in training.
train_metrics (List[MetricsDict]) – List of diagnostic statistics computed from training set.
t0 – Time when training loop started.
- scico.flax.only_apply(config, model, test_ds, apply_fn=<function apply_fn>, variables=None)
Execute model application loop.
- Parameters:
config (ConfigDict) – Hyperparameter configuration.
model (Any) – Flax model to apply.
test_ds (DataSetDict) – Dictionary of testing data (includes images and labels).
apply_fn (Callable) – A hook for a function that applies the current model. Default: apply_fn, i.e. use the standard apply function.
variables (Optional[ModelVarDict]) – Model parameters to use for evaluation. Default: None (i.e. read from checkpoint).
- Returns:
Output of model evaluated at the input provided in test_ds.
- Raises:
RuntimeError – If no model variables and no checkpoint are specified.
- scico.flax.count_parameters(params)
Return count of variables for the parameter dictionary.
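As an illustration of what such a count involves, the sketch below sums the number of elements over all leaves of a nested parameter dictionary. It is an assumption about the general approach, not the scico implementation. The name `count_variables` and the layer names are hypothetical, and shape tuples stand in for arrays so that the example needs no array library.

```python
import math
from typing import Any, Dict


def count_variables(tree: Dict[str, Any]) -> int:
    """Sum the number of elements over all leaves of a nested dictionary."""
    total = 0
    for value in tree.values():
        if isinstance(value, dict):
            total += count_variables(value)  # recurse into sub-dictionaries
        else:
            total += math.prod(value)  # leaf is a shape tuple in this sketch
    return total


# Hypothetical parameter layout: one 3x3 convolution with 64 filters plus batch norm.
toy_params = {
    "conv": {"kernel": (3, 3, 1, 64), "bias": (64,)},
    "bn": {"scale": (64,), "bias": (64,)},
}
num_variables = count_variables(toy_params)
```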
Modules
scico.flax.blocks | Flax implementation of different convolutional blocks.
scico.flax.examples | Data utility functions used by Flax example scripts.
scico.flax.inverse | Flax implementation of different imaging inversion models.
scico.flax.train | Utilities for training Flax models.