# -*- coding: utf-8 -*-# Copyright (C) 2022-2023 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Utilities for checkpointing Flax models."""frompathlibimportPathfromtypingimportUnionimportjaximportorbax.checkpointfromflax.trainingimportorbax_utilsfrom.stateimportTrainStatefrom.typed_dictimportConfigDict
def checkpoint_restore(
    state: TrainState, workdir: Union[str, Path], ok_no_ckpt: bool = False
) -> TrainState:
    """Load model and optimiser state.

    Args:
        state: Flax train state which includes model and optimiser
            parameters.
        workdir: Checkpoint file or directory of checkpoints to restore
            from.
        ok_no_ckpt: Flag to indicate if a checkpoint is expected. If
            ``False``, an error is generated if a checkpoint is not
            found.

    Returns:
        A restored Flax train state updated from checkpoint file is
        returned. If no checkpoint is found (either `workdir` does not
        exist or the checkpoint manager reports no saved step) and
        checkpoints are not strictly expected, the passed-in `state` is
        returned unchanged.

    Raises:
        FileNotFoundError: If a checkpoint is expected and is not found.
    """
    # Path() accepts both str and Path, so no explicit type check is needed.
    ckpt_path = Path(workdir)

    if ckpt_path.exists():
        checkpoint_manager = orbax.checkpoint.CheckpointManager(
            ckpt_path, orbax.checkpoint.PyTreeCheckpointer()
        )
        step = checkpoint_manager.latest_step()
        if step is not None:
            # The target mirrors the {"state", "config"} bundle written by
            # checkpoint_save; only the state entry is handed back.
            target = {"state": state, "config": {}}
            ckpt = checkpoint_manager.restore(step, items=target)
            return ckpt["state"]

    # Reached when the path is missing or when it holds no saved step.
    # Both count as "checkpoint not found", so both must honour ok_no_ckpt.
    if not ok_no_ckpt:
        raise FileNotFoundError("Could not read from checkpoint: " + str(workdir))

    return state
def checkpoint_save(state: TrainState, config: ConfigDict, workdir: Union[str, Path]):
    """Store model, model configuration, and optimiser state.

    The name deliberately differs slightly from the corresponding Flax
    functions so that the two are not confused.

    Only the process for which :func:`jax.process_index` is 0 writes a
    checkpoint; every other process returns without doing anything.

    Args:
        state: Flax train state which includes model and optimiser
            parameters.
        config: Python dictionary including model train configuration.
        workdir: Path in which to store checkpoint files.
    """
    if jax.process_index() != 0:
        return

    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Config and train state are written together as a single bundle.
    bundle = {"state": state, "config": config}
    bundle_save_args = orbax_utils.save_args_from_target(bundle)
    manager_options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=3, create=True)
    manager = orbax.checkpoint.CheckpointManager(workdir, checkpointer, manager_options)
    # The checkpoint is indexed by the training step held in the state.
    manager.save(int(state.step), bundle, save_kwargs={"save_args": bundle_save_args})