Source code for scico.flax.train.traversals
# -*- coding: utf-8 -*-
# Copyright (C) 2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functionality to traverse, select, and update model parameters."""
from typing import Any
import jax.numpy as jnp
from flax.traverse_util import ModelParamTraversal
PyTree = Any
def construct_traversal(prmname: str) -> ModelParamTraversal:
    """Build a traversal that selects model parameters by name.

    A parameter is selected when ``prmname`` is found in its path
    (``prmname in path``); the parameter value itself is not
    inspected.

    Args:
        prmname: Name of parameter to select.

    Returns:
        Flax utility to traverse and select model parameters.
    """

    def _name_in_path(path, _value) -> bool:
        # Selection depends only on the path; the value is ignored.
        return prmname in path

    return ModelParamTraversal(_name_in_path)
def clip_positive(params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4) -> PyTree:
    """Clip parameters to positive range.

    Every parameter selected by `traversal` is bounded below by
    `minval`; no upper bound is applied.

    Args:
        params: Current model parameters.
        traversal: Utility to select model parameters.
        minval: Minimum value to clip selected model parameters and keep
            them in a positive range. Default: 1e-4.

    Returns:
        The result of ``traversal.update``: the model parameters with
        the selected entries clipped from below at `minval`.
    """

    def _clip_below(x):
        # The lower bound is passed positionally: the ``a_min`` keyword
        # is deprecated in recent JAX releases (renamed to ``min``),
        # whereas the positional form is accepted by both signatures.
        return jnp.clip(x, minval)

    return traversal.update(_clip_below, params)
def clip_range(
    params: PyTree, traversal: ModelParamTraversal, minval: float = 1e-4, maxval: float = 1
) -> PyTree:
    """Clip parameters to specified range.

    Every parameter selected by `traversal` is restricted to the
    interval [`minval`, `maxval`].

    Args:
        params: Current model parameters.
        traversal: Utility to select model parameters.
        minval: Minimum value to clip selected model parameters.
            Default: 1e-4.
        maxval: Maximum value to clip selected model parameters.
            Default: 1.

    Returns:
        The result of ``traversal.update``: the model parameters with
        the selected entries clipped to [`minval`, `maxval`].
    """

    def _clip_both(x):
        # Bounds are passed positionally: the ``a_min``/``a_max``
        # keywords are deprecated in recent JAX releases (renamed to
        # ``min``/``max``), whereas the positional form is accepted by
        # both signatures.
        return jnp.clip(x, minval, maxval)

    return traversal.update(_clip_both, params)