Source code for scico.flax.train.checkpoints
# -*- coding: utf-8 -*-
# Copyright (C) 2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Utilities for checkpointing Flax models."""
from pathlib import Path
from typing import Union
import jax
import orbax
from flax.training import orbax_utils
from .state import TrainState
from .typed_dict import ConfigDict
def checkpoint_restore(
    state: TrainState, workdir: Union[str, Path], ok_no_ckpt: bool = False
) -> TrainState:
    """Load model and optimiser state from the latest checkpoint.

    Args:
        state: Flax train state which includes model and optimiser
            parameters. It is used as the restore target and is returned
            unchanged when no checkpoint is available and `ok_no_ckpt`
            is ``True``.
        workdir: Directory of checkpoints to restore from.
        ok_no_ckpt: Flag indicating whether a missing checkpoint is
            acceptable. Default: ``False``, i.e. a checkpoint is expected
            and an error is raised if none is found.

    Returns:
        A Flax train state restored from the latest checkpoint step.
        If no checkpoint is present and `ok_no_ckpt` is ``True``, the
        passed-in `state` is returned unchanged.

    Raises:
        FileNotFoundError: If a checkpoint is expected (`ok_no_ckpt` is
            ``False``) and either `workdir` does not exist or it contains
            no checkpoint.
    """
    # Only plain strings are converted; other path-like objects are used
    # as given.
    workdir_ = Path(workdir) if isinstance(workdir, str) else workdir

    if workdir_.exists():
        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
        checkpoint_manager = orbax.checkpoint.CheckpointManager(workdir_, orbax_checkpointer)
        # latest_step() gives None when the directory holds no saved steps.
        step = checkpoint_manager.latest_step()
        if step is not None:
            # The target has the same top-level keys that checkpoint_save
            # writes ("state" and "config"). Only the restored state is
            # returned.
            target = {"state": state, "config": {}}
            ckpt = checkpoint_manager.restore(step, items=target)
            return ckpt["state"]

    # Reaching this point means that either workdir does not exist or it
    # contains no checkpoint. Both cases are "checkpoint not found".
    if not ok_no_ckpt:
        raise FileNotFoundError("Could not read from checkpoint: " + str(workdir))
    return state
def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, Path]):
    """Write model state, model configuration and optimiser state to disk.

    The name deliberately differs slightly from the similar Flax
    functions so the two are not confused.

    Only the process with ``jax.process_index() == 0`` writes a
    checkpoint; every other process returns immediately. At most three
    checkpoints are retained in `workdir`, which is created if needed.

    Args:
        state: Flax train state which includes model and optimiser
            parameters. Its ``step`` attribute is used as the checkpoint
            step number.
        config: Python dictionary including model train configuration.
        workdir: str or pathlib-like path to store checkpoint files in.
    """
    # Guard clause: non-lead processes do not write anything.
    if jax.process_index() != 0:
        return

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()

    # Store the configuration alongside the train state in one pytree.
    bundle = {"state": state, "config": config}
    bundle_save_args = orbax_utils.save_args_from_target(bundle)

    manager_options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
    manager = orbax.checkpoint.CheckpointManager(workdir, checkpointer, manager_options)

    current_step = int(state.step)
    manager.save(current_step, bundle, save_kwargs={"save_args": bundle_save_args})