scico.flax.train.checkpoints
Utilities for checkpointing Flax models.
Functions
- checkpoint_restore(state, workdir[, ok_no_ckpt]) – Load model and optimiser state.
- checkpoint_save(state, config, workdir) – Store model, model configuration, and optimiser state.
- scico.flax.train.checkpoints.checkpoint_restore(state, workdir, ok_no_ckpt=False)
Load model and optimiser state.
- Parameters:
state (TrainState) – Flax train state which includes model and optimiser parameters.
workdir (Union[str, Path]) – Checkpoint file or directory of checkpoints to restore from.
ok_no_ckpt (bool) – Flag indicating whether a missing checkpoint is acceptable. Default: False, meaning that a checkpoint is expected and an error is raised if none is found.
- Return type:
TrainState
- Returns:
A Flax train state restored from the checkpoint file. If no checkpoint files are present and a checkpoint is not strictly expected, the passed-in state is returned unchanged.
- Raises:
FileNotFoundError – If a checkpoint is expected and is not found.
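The interplay between `workdir` and `ok_no_ckpt` can be illustrated with a minimal, stdlib-only sketch. It does not use Flax or scico: the function `restore_sketch`, the hypothetical `ckpt_<step>.pkl` file naming, and the use of `pickle` are assumptions chosen only to show the fallback logic described above (restore if a checkpoint exists, otherwise return the passed-in state or raise `FileNotFoundError`).

```python
import pickle
from pathlib import Path
from typing import Any, Union


def restore_sketch(state: Any, workdir: Union[str, Path], ok_no_ckpt: bool = False) -> Any:
    """Hypothetical illustration of restore-or-fallback logic (not the scico API)."""
    workdir = Path(workdir)
    if workdir.is_file():
        # A single checkpoint file was passed directly.
        candidates = [workdir]
    else:
        # A directory: collect hypothetical files named ckpt_<step>.pkl,
        # sorted numerically by step so that ckpt_10 comes after ckpt_2.
        candidates = sorted(
            workdir.glob("ckpt_*.pkl"), key=lambda p: int(p.stem.split("_")[1])
        )
    if not candidates:
        if ok_no_ckpt:
            # No checkpoint, but that is acceptable: return the input state unchanged.
            return state
        raise FileNotFoundError(f"No checkpoint found in {workdir}")
    # Restore from the most recent checkpoint (highest step number).
    with open(candidates[-1], "rb") as f:
        return pickle.load(f)
```

With an empty directory, `restore_sketch(state, tmp, ok_no_ckpt=True)` hands back the same `state` object, while the default `ok_no_ckpt=False` raises `FileNotFoundError`, mirroring the documented behaviour.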
- scico.flax.train.checkpoints.checkpoint_save(state, config, workdir)
Store model, model configuration, and optimiser state.
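Note that the naming differs slightly from that of the corresponding Flax functions (flax.training.checkpoints.save_checkpoint and restore_checkpoint) to distinguish the two.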
- Parameters:
state (TrainState) – Flax train state which includes model and optimiser parameters.
config (ConfigDict) – Configuration dictionary containing the model training configuration.
workdir (Union[str, Path]) – Directory, given as a str or pathlib path, in which to store checkpoint files.
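One way to picture what is stored is a single checkpoint that bundles the train state with its configuration and is labelled by the training step. The stdlib-only sketch below is hypothetical: `save_sketch`, the plain-dict state with a `"step"` key, the `ckpt_<step>.pkl` naming, and `pickle` serialization are illustrative assumptions, not scico's implementation.

```python
import pickle
from pathlib import Path
from typing import Any, Mapping, Union


def save_sketch(
    state: Mapping[str, Any], config: Mapping[str, Any], workdir: Union[str, Path]
) -> Path:
    """Hypothetical illustration: store state and config together (not the scico API)."""
    # Accept either a str or a pathlib.Path, mirroring the workdir parameter type.
    workdir = Path(workdir).absolute()
    workdir.mkdir(parents=True, exist_ok=True)
    # Label the checkpoint with the training step recorded in the (dict) state.
    step = int(state["step"])
    path = workdir / f"ckpt_{step}.pkl"
    # Bundle state and configuration so both can be recovered from one file.
    with open(path, "wb") as f:
        pickle.dump({"state": dict(state), "config": dict(config)}, f)
    return path
```

Keeping the configuration next to the state means a restored run can be checked against, or rebuilt from, the settings it was trained with.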