# -*- coding: utf-8 -*-# Copyright (C) 2022-2024 by SCICO Developers# All rights reserved. BSD 3-clause License.# This file is part of the SCICO package. Details of the copyright and# user license can be found in the 'LICENSE' file distributed with the# package."""Functionality to traverse, select, and update model parameters."""fromtypingimportAnyimportjax.numpyasjnpfromflax.traverse_utilimportModelParamTraversalPyTree=Any
def construct_traversal(prmname: str) -> ModelParamTraversal:
    """Build a Flax traversal that selects parameters by name.

    A parameter is selected whenever `prmname` occurs within its path
    (a substring test, so e.g. ``"alpha"`` also matches a path such as
    ``".../alpha_0"``).

    Args:
        prmname: Name (or name fragment) of the parameter(s) to select.

    Returns:
        Flax utility to traverse and select model parameters.
    """

    def _path_contains_name(path, _leaf):
        # NOTE(review): `path` is presumably the string path Flax builds for
        # each parameter leaf — confirm against the installed Flax version.
        return prmname in path

    return ModelParamTraversal(_path_contains_name)
def clip_positive(
    params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4
) -> PyTree:
    """Clip selected parameters from below to keep them positive.

    Each parameter picked out by `traversal` is passed through
    :func:`jax.numpy.clip` with `minval` as the lower bound and no upper
    bound. The result is whatever `traversal.update` produces.

    Args:
        params: Current model parameters.
        traversal: Utility to select model parameters.
        minval: Lower bound applied to the selected model parameters so
            that they stay in a positive range. Default: 1e-4.

    Returns:
        Model parameters with the selected entries clipped.
    """

    def _floor_at_minval(leaf):
        return jnp.clip(leaf, minval)

    return traversal.update(_floor_at_minval, params)
def clip_range(
    params: PyTree,
    traversal: ModelParamTraversal,
    minval: float = 1e-4,
    maxval: float = 1,
) -> PyTree:
    """Clip selected parameters into the interval [`minval`, `maxval`].

    Each parameter picked out by `traversal` is passed through
    :func:`jax.numpy.clip` with both a lower and an upper bound. The result
    is whatever `traversal.update` produces.

    Args:
        params: Current model parameters.
        traversal: Utility to select model parameters.
        minval: Lower bound for the selected model parameters.
            Default: 1e-4.
        maxval: Upper bound for the selected model parameters.
            Default: 1.

    Returns:
        Model parameters with the selected entries clipped.
    """

    def _bound_to_range(leaf):
        return jnp.clip(leaf, minval, maxval)

    return traversal.update(_bound_to_range, params)