scico.flax.train.checkpoints
Utilities for checkpointing Flax models.
Functions
| checkpoint_restore(state, workdir[, ok_no_ckpt]) | Load model and optimiser state. |
| checkpoint_save(state, config, workdir) | Store model, model configuration, and optimiser state. |
- scico.flax.train.checkpoints.checkpoint_restore(state, workdir, ok_no_ckpt=False)
Load model and optimiser state.
- Parameters:
state (TrainState) – Flax train state which includes model and optimiser parameters.
workdir (Union[str, Path]) – Checkpoint file or directory of checkpoints to restore from.
ok_no_ckpt (bool) – Flag indicating whether a missing checkpoint is acceptable. If False, an error is raised when no checkpoint is found.
- Return type:
TrainState
- Returns:
A Flax train state restored from the checkpoint file. If no checkpoint files are present and a checkpoint is not required (ok_no_ckpt is True), the passed-in state is returned unchanged.
- Raises:
FileNotFoundError – If a checkpoint is expected and is not found.
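To make the ok_no_ckpt contract concrete, the following is a minimal, stand-alone sketch of the documented restore behaviour in plain Python. It is not scico's implementation: the helper `restore_or_default`, its `load` callback, and the `*.ckpt` file-name pattern are hypothetical, chosen only so that the example runs without Flax installed.

```python
from pathlib import Path
from typing import Any, Callable, Union


def restore_or_default(
    state: Any,
    workdir: Union[str, Path],
    load: Callable[[Path], Any],
    ok_no_ckpt: bool = False,
) -> Any:
    """Hypothetical sketch of the checkpoint_restore contract (not scico code)."""
    path = Path(workdir)
    if path.is_file():
        # workdir names a single checkpoint file: restore from it directly
        return load(path)
    # workdir is treated as a directory; "*.ckpt" is a hypothetical naming scheme
    ckpts = sorted(path.glob("*.ckpt")) if path.is_dir() else []
    if not ckpts:
        if ok_no_ckpt:
            # A missing checkpoint is acceptable: hand back the input state as-is
            return state
        raise FileNotFoundError(f"No checkpoint found in {path}")
    # Restore from the last checkpoint in sorted file-name order
    return load(ckpts[-1])
```

With an empty directory, passing ok_no_ckpt=True returns the original state object, while the default ok_no_ckpt=False raises FileNotFoundError, mirroring the Returns and Raises entries above.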
- scico.flax.train.checkpoints.checkpoint_save(state, config, workdir)
Store model, model configuration, and optimiser state.
Note that the function names differ slightly from those of the corresponding Flax functions (save_checkpoint and restore_checkpoint) so that the two can be distinguished.
- Parameters:
state (TrainState) – Flax train state which includes model and optimiser parameters.
config (ConfigDict) – Python dictionary containing the model training configuration.
workdir (Union[str, Path]) – Path of the directory in which to store checkpoint files.
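The point of saving the configuration together with the state is that a checkpoint directory then records how the model was trained. A minimal stand-alone sketch of that idea follows. The function `save_sketch`, the explicit `step` argument, the `config.json` and `<step>.ckpt` file names, and the use of `pickle` and `json` are all hypothetical and do not reproduce scico's on-disk checkpoint format.

```python
import json
import pickle
from pathlib import Path
from typing import Any, Mapping, Union


def save_sketch(
    state: Any,
    config: Mapping[str, Any],
    workdir: Union[str, Path],
    step: int,
) -> Path:
    """Hypothetical stand-in for checkpoint_save: write state and config side by side."""
    path = Path(workdir)
    path.mkdir(parents=True, exist_ok=True)
    # The training configuration is stored as human-readable JSON
    (path / "config.json").write_text(json.dumps(dict(config), indent=2))
    # The state (model and optimiser parameters) is pickled; zero-padding the
    # step keeps lexicographic and numeric ordering of checkpoint files identical
    ckpt = path / f"{step:08d}.ckpt"
    ckpt.write_bytes(pickle.dumps(state))
    return ckpt
```

Zero-padded step numbers are one simple way to let a restore routine pick the most recent checkpoint by sorting file names.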