scico.flax.train.typed_dict

Definition of typed dictionaries for objects in training functionality.

Classes

ConfigDict

Dictionary structure for training parameters.

DataSetDict

Dictionary structure for training data sets.

MetricsDict

Dictionary structure for training metrics.

ModelVarDict

Dictionary structure for Flax variables.

class scico.flax.train.typed_dict.DataSetDict

Bases: TypedDict

Dictionary structure for training data sets.

Definition of the dictionary structure expected for the training data sets.

image: Array

Input data, with shape (Num. samples x Height x Width x Channels).

label: Array

Target output, with shape (Num. samples x Height x Width x Channels) for image outputs, or (Num. samples x Classes) for class outputs.
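A minimal sketch of building data sets in this layout, assuming NumPy arrays stand in for JAX arrays and using random values purely as placeholder data. So that the snippet runs without scico installed, it declares a local `TypedDict` with the same two fields; in practice you would import `DataSetDict` from `scico.flax.train.typed_dict`.

```python
from typing import TypedDict

import numpy as np


# Local stand-in mirroring the documented fields; with scico installed use
# `from scico.flax.train.typed_dict import DataSetDict` instead.
class DataSetDict(TypedDict):
    image: np.ndarray
    label: np.ndarray


rng = np.random.default_rng(0)
num_train, num_test, height, width, channels = 32, 8, 16, 16, 1

# Image-to-image task: inputs and outputs share the (N, H, W, C) layout.
train_ds: DataSetDict = {
    "image": rng.standard_normal((num_train, height, width, channels)),
    "label": rng.standard_normal((num_train, height, width, channels)),
}
test_ds: DataSetDict = {
    "image": rng.standard_normal((num_test, height, width, channels)),
    "label": rng.standard_normal((num_test, height, width, channels)),
}

# Class outputs: labels are (N, Classes), here one-hot vectors.
num_classes = 10
class_idx = rng.integers(0, num_classes, size=num_train)
cls_ds: DataSetDict = {
    "image": train_ds["image"],
    "label": np.eye(num_classes)[class_idx],
}

# TypedDict annotations are not enforced at runtime, so check sizes explicitly.
assert train_ds["image"].shape[0] == train_ds["label"].shape[0]
```

Because `TypedDict` is a static annotation only, a mismatch between the number of inputs and labels is not caught automatically; an explicit check like the final `assert` is one way to surface it early.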

class scico.flax.train.typed_dict.ConfigDict

Bases: TypedDict

Dictionary structure for training parameters.

Definition of the dictionary structure expected for specifying training parameters.

seed: float

Value used to seed the random number generator.

opt_type: str

Type of optimizer. Options: SGD, ADAM, ADAMW.

momentum: float

Momentum for the SGD optimizer (applied when Nesterov is True).

batch_size: int

Batch size for training.

num_epochs: int

Number of epochs for training (an epoch is one whole pass through the training dataset).

base_learning_rate: float

Starting learning rate for scheduling.

lr_decay_rate: float

Decay rate of the learning rate when scheduling is used.

warmup_epochs: int

Number of warmup epochs (if warmup scheduling is used).

steps_per_eval: int

Number of training steps between evaluations on the test set.

log_every_steps: int

Number of training steps between printouts of the current train and test metrics.

steps_per_epoch: int

Number of training steps executed per epoch (depends on batch size).
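As a worked example of how the number of steps per epoch depends on the batch size, the sketch below assumes each step consumes one full batch and a trailing partial batch is dropped (floor division). That rounding rule is an assumption, not necessarily what the training code does, and the helper name `training_steps` is hypothetical.

```python
from typing import Tuple


def training_steps(num_samples: int, batch_size: int, num_epochs: int) -> Tuple[int, int]:
    """Hypothetical helper: return (steps_per_epoch, total_steps).

    Assumes an incomplete final batch is dropped (floor division).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    steps_per_epoch = num_samples // batch_size
    return steps_per_epoch, steps_per_epoch * num_epochs


# 1000 samples in batches of 64: 15 full batches per epoch, 40 samples left over.
steps_per_epoch, total_steps = training_steps(num_samples=1000, batch_size=64, num_epochs=20)
print(steps_per_epoch, total_steps)  # 15 300
```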

steps_per_checkpoint: int

Number of training steps between model checkpoints (if checkpointing is True).
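The period settings `steps_per_eval`, `log_every_steps` and `steps_per_checkpoint` can be read as step intervals. The sketch below assumes 1-based step counting, that an action fires whenever the step is a multiple of its period, and that the `log` and `checkpointing` flags gate printing and saving; these are assumptions rather than a description of scico's training loop, and `scheduled_actions` is a hypothetical name.

```python
from typing import Any, Dict, List


def scheduled_actions(step: int, config: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: actions triggered at a 1-based training step."""
    actions = []
    if step % config["steps_per_eval"] == 0:
        actions.append("eval")
    # Printing only happens if the `log` flag is set (assumption).
    if config["log"] and step % config["log_every_steps"] == 0:
        actions.append("log")
    # Saving only happens if `checkpointing` is set.
    if config["checkpointing"] and step % config["steps_per_checkpoint"] == 0:
        actions.append("checkpoint")
    return actions


config = {
    "steps_per_eval": 50,
    "log_every_steps": 25,
    "steps_per_checkpoint": 100,
    "log": True,
    "checkpointing": True,
}
for step in (25, 50, 100):
    print(step, scheduled_actions(step, config))
# 25 ['log']
# 50 ['eval', 'log']
# 100 ['eval', 'log', 'checkpoint']
```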

log: bool

Flag indicating whether the evolution of the metrics is printed during training.

workdir: str

Path to directory for checkpointing model parameters.

checkpointing: bool

Flag indicating whether model parameters and optimizer state are stored during training.

return_state: bool

Flag indicating whether the state (params and batch_stats) is returned at the end of training.

lr_schedule: Callable

Function that modifies the learning rate during training (an optax schedule).
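An optax schedule is a callable that maps the current step count to a learning rate. The plain-Python sketch below builds such a callable from `base_learning_rate`, `warmup_epochs`, `steps_per_epoch` and `lr_decay_rate`, assuming linear warmup followed by exponential decay applied once per epoch. This particular combination and the name `make_warmup_exp_decay` are illustrative assumptions, not the schedule scico provides; with optax installed, `optax.linear_schedule`, `optax.exponential_decay` and `optax.join_schedules` offer comparable building blocks.

```python
from typing import Callable


def make_warmup_exp_decay(
    base_learning_rate: float,
    warmup_epochs: int,
    steps_per_epoch: int,
    lr_decay_rate: float,
) -> Callable[[int], float]:
    """Hypothetical schedule: linear warmup, then per-epoch exponential decay."""
    warmup_steps = warmup_epochs * steps_per_epoch

    def schedule(step: int) -> float:
        if step < warmup_steps:
            # Ramp linearly from 0 towards base_learning_rate during warmup.
            return base_learning_rate * step / warmup_steps
        # After warmup, multiply by lr_decay_rate once per completed epoch.
        epochs_after_warmup = (step - warmup_steps) // steps_per_epoch
        return base_learning_rate * lr_decay_rate**epochs_after_warmup

    return schedule


lr_schedule = make_warmup_exp_decay(
    base_learning_rate=1e-3, warmup_epochs=2, steps_per_epoch=100, lr_decay_rate=0.5
)
for step in (0, 100, 200, 300, 400):
    print(step, lr_schedule(step))
```

With `warmup_epochs=0` the warmup branch is never taken, so the same function also covers the no-warmup case without dividing by zero.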

criterion: Callable

Criterion to optimize during training.
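For a regression-style task such as image restoration, one common criterion is the mean squared error between the model output and the label. The NumPy sketch below assumes the criterion is called as `criterion(output, label)` and returns a scalar; that argument order and the name `mse_criterion` are assumptions.

```python
import numpy as np


def mse_criterion(output: np.ndarray, label: np.ndarray) -> float:
    """Mean squared error averaged over all samples and array entries."""
    return float(np.mean((output - label) ** 2))


output = np.array([[1.0, 2.0], [3.0, 4.0]])
label = np.array([[1.0, 2.0], [3.0, 6.0]])
print(mse_criterion(output, label))  # 1.0
```

Inside a JAX training step the same expression would be written with `jax.numpy` and without the `float(...)` conversion, so that it can be traced and differentiated.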

create_train_state: Callable

Function to create and initialize the training state. It should include initialization of the optimizer and of batch_stats (if applicable).

train_step_fn: Callable

Function to execute each training step.

eval_step_fn: Callable

Function to execute each evaluation step.

metrics_fn: Callable

Function to track metrics during training.

post_lst: List[Callable]

List of post-processing functions to apply after a train step (if any).
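A sketch of applying a list of post-processing functions in order after a training step. The signature assumed here, where each function takes and returns the model parameters as a dictionary of arrays, is hypothetical, as are the two example functions (a non-negativity projection and a norm clamp).

```python
from typing import Callable, Dict, List

import numpy as np

Params = Dict[str, np.ndarray]


def clip_nonnegative(params: Params) -> Params:
    """Hypothetical post-processing: project all weights onto [0, inf)."""
    return {k: np.maximum(v, 0.0) for k, v in params.items()}


def clamp_norm(params: Params, max_norm: float = 1.0) -> Params:
    """Hypothetical post-processing: rescale each array to norm <= max_norm."""
    out = {}
    for k, v in params.items():
        norm = np.linalg.norm(v)
        out[k] = v * (max_norm / norm) if norm > max_norm else v
    return out


def apply_post(params: Params, post_lst: List[Callable[[Params], Params]]) -> Params:
    # Functions are applied in list order, each receiving the previous output.
    for fn in post_lst:
        params = fn(params)
    return params


params = {"w": np.array([3.0, -4.0, 0.0])}
params = apply_post(params, [clip_nonnegative, clamp_norm])
print(params["w"])  # [1. 0. 0.]
```

Order matters: clamping before clipping would rescale by the norm of the unclipped array (5 rather than 3) and give a different result.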

class scico.flax.train.typed_dict.ModelVarDict

Bases: TypedDict

Dictionary structure for Flax variables.

Definition of the dictionary structure grouping all Flax model variables.

params: Any

Model weights and biases.

batch_stats: Any

Batch statistics (e.g. normalization parameters that depend on training data).
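In Flax, `Module.init` returns variables grouped into collections such as `params` and, for models with batch normalization, `batch_stats`. The sketch below fills the two fields by hand with NumPy arrays so that it runs without Flax; the layer names are hypothetical, and a local `TypedDict` stands in for `ModelVarDict` from `scico.flax.train.typed_dict`.

```python
from typing import Any, TypedDict

import numpy as np


# Local stand-in mirroring the documented fields; with scico installed use
# `from scico.flax.train.typed_dict import ModelVarDict` instead.
class ModelVarDict(TypedDict):
    params: Any
    batch_stats: Any


# Hypothetical layer names; real names come from the Flax model definition.
variables: ModelVarDict = {
    "params": {
        "Conv_0": {"kernel": np.zeros((3, 3, 1, 8)), "bias": np.zeros(8)},
        "BatchNorm_0": {"scale": np.ones(8), "bias": np.zeros(8)},
    },
    "batch_stats": {
        # Running statistics computed from training batches, not by the optimizer.
        "BatchNorm_0": {"mean": np.zeros(8), "var": np.ones(8)},
    },
}


def count_leaves(tree: Any) -> int:
    """Count scalar entries in a nested dict of arrays."""
    if isinstance(tree, dict):
        return sum(count_leaves(v) for v in tree.values())
    return int(np.size(tree))


print(count_leaves(variables["params"]))  # 96 = 3*3*1*8 + 8 + 8 + 8
```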

class scico.flax.train.typed_dict.MetricsDict

Bases: TypedDict

Dictionary structure for training metrics.

Definition of the dictionary structure for metrics computed or updates made during training.

loss: float

Value of the criterion being optimized.

snr: float

Signal-to-noise ratio (SNR).

learning_rate: float

Current learning rate.
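A sketch of filling a metrics dictionary for one batch. It assumes the loss is the mean squared error and computes the SNR in dB as 10 log10(var(reference) / MSE), which is one common definition; the training code may compute either quantity differently. The helper `snr_db` is hypothetical, and a local `TypedDict` stands in for `MetricsDict`.

```python
from typing import TypedDict

import numpy as np


# Local stand-in mirroring the documented fields; with scico installed use
# `from scico.flax.train.typed_dict import MetricsDict` instead.
class MetricsDict(TypedDict):
    loss: float
    snr: float
    learning_rate: float


def snr_db(reference: np.ndarray, estimate: np.ndarray) -> float:
    """SNR in dB as 10 log10(var(reference) / MSE) (assumed definition)."""
    mse = np.mean((reference - estimate) ** 2)
    return float(10.0 * np.log10(np.var(reference) / mse))


reference = np.array([1.0, -1.0, 1.0, -1.0])  # variance 1
estimate = reference + 0.1  # constant error of 0.1, so MSE = 0.01

metrics: MetricsDict = {
    "loss": float(np.mean((reference - estimate) ** 2)),  # MSE loss (assumption)
    "snr": snr_db(reference, estimate),
    "learning_rate": 1e-3,
}
print(round(metrics["snr"], 6))  # 20.0
```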