scico.flax.train.checkpoints

Utilities for checkpointing Flax models.

Functions

checkpoint_restore(state, workdir[, ok_no_ckpt])

Load model and optimiser state.

checkpoint_save(state, config, workdir)

Store model, model configuration, and optimiser state.

scico.flax.train.checkpoints.checkpoint_restore(state, workdir, ok_no_ckpt=False)

Load model and optimiser state.

Parameters:
  • state (TrainState) – Flax train state, which includes the model and optimiser parameters.

  • workdir (Union[str, Path]) – Checkpoint file or directory of checkpoints to restore from.

  • ok_no_ckpt (bool) – Flag indicating whether it is acceptable for no checkpoint to be found. If False (the default), a checkpoint is expected and an error is raised if none is found. If True, the passed-in state is returned unchanged when no checkpoint is present.

Return type:

TrainState

Returns:

The Flax train state restored from the checkpoint. If no checkpoint files are present and a checkpoint is not required (ok_no_ckpt is True), the passed-in state is returned unchanged.

Raises:

FileNotFoundError – If a checkpoint is expected (ok_no_ckpt is False) but none is found.
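The Returns and Raises entries describe three cases: a checkpoint is found and loaded, no checkpoint is found and the input state is passed back, or no checkpoint is found and FileNotFoundError is raised. The following is a minimal plain-Python sketch of that control flow only. The names restore_sketch and _has_checkpoint, the load_fn callback, and the rule that an empty directory counts as "no checkpoint" are all hypothetical assumptions; the actual deserialisation performed by checkpoint_restore is not shown.

```python
import tempfile
from pathlib import Path
from typing import Any, Callable, Union


def _has_checkpoint(path: Path) -> bool:
    """Assumption: a file, or a non-empty directory, counts as a checkpoint."""
    if path.is_file():
        return True
    return path.is_dir() and any(path.iterdir())


def restore_sketch(
    state: Any,
    workdir: Union[str, Path],
    load_fn: Callable[[Path, Any], Any],
    ok_no_ckpt: bool = False,
) -> Any:
    """Hypothetical sketch of the documented checkpoint_restore control flow."""
    path = Path(workdir)
    if not _has_checkpoint(path):
        if not ok_no_ckpt:
            # A checkpoint is required (the default) but none was found.
            raise FileNotFoundError(f"No checkpoint found in {path}")
        # Checkpoint optional: hand back the passed-in state unchanged.
        return state
    # load_fn is a hypothetical stand-in for the real deserialisation step.
    return load_fn(path, state)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        init_state = {"step": 0}
        # Starting a fresh run: no checkpoint exists yet, so the
        # initial state is returned as-is.
        resumed = restore_sketch(
            init_state, Path(tmp) / "ckpts", lambda p, s: s, ok_no_ckpt=True
        )
        print(resumed is init_state)
```

This split explains why ok_no_ckpt=True suits code that may either start fresh or resume: the same call works in both situations, whereas the default of False surfaces a missing or mistyped checkpoint path as an error.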

scico.flax.train.checkpoints.checkpoint_save(state, config, workdir)

Store model, model configuration, and optimiser state.

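Note that the function names differ slightly from those of the corresponding Flax checkpointing functions (e.g. flax.training.checkpoints.save_checkpoint) to distinguish the two.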

Parameters:
  • state (TrainState) – Flax train state, which includes the model and optimiser parameters.

  • config (ConfigDict) – Configuration dictionary containing the model and training configuration.

  • workdir (Union[str, Path]) – Directory in which to store checkpoint files, given as a str or path-like object.
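The point of passing config alongside state is that a checkpoint carries enough information to reconstruct both the trained model and the settings that produced it. The sketch below illustrates that idea with the standard library only. save_sketch, load_sketch, the file name checkpoint.pkl, the use of pickle, and the example config keys are all hypothetical; this is not the on-disk format written by checkpoint_save.

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Tuple, Union

CKPT_NAME = "checkpoint.pkl"  # hypothetical file name


def save_sketch(state: Any, config: Dict[str, Any], workdir: Union[str, Path]) -> Path:
    """Hypothetical illustration: store state and config together in workdir.

    pickle is used purely for demonstration and is NOT the format used by
    scico.flax.train.checkpoints.checkpoint_save.
    """
    path = Path(workdir)
    path.mkdir(parents=True, exist_ok=True)  # create the directory if needed
    ckpt_file = path / CKPT_NAME
    with open(ckpt_file, "wb") as f:
        # Bundle state and configuration into a single record.
        pickle.dump({"state": state, "config": config}, f)
    return ckpt_file


def load_sketch(workdir: Union[str, Path]) -> Tuple[Any, Dict[str, Any]]:
    """Hypothetical counterpart that recovers both state and config."""
    with open(Path(workdir) / CKPT_NAME, "rb") as f:
        ckpt = pickle.load(f)
    return ckpt["state"], ckpt["config"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        state = {"step": 10, "params": [0.5, -1.0]}
        config = {"batch_size": 16, "num_epochs": 5}
        save_sketch(state, config, Path(tmp) / "ckpts")
        print(load_sketch(Path(tmp) / "ckpts"))
```

Keeping the configuration in the same record as the state avoids a checkpoint being restored with settings that differ from those it was trained under.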