Source code for scico.flax.train.learning_rate
# -*- coding: utf-8 -*-
# Copyright (C) 2022 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Learning rate schedulers."""
import optax
from .typed_dict import ConfigDict
[docs]def create_cnst_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:
"""Create learning rate to be a constant specified
value.
Args:
config: Dictionary of configuration. The value to use corresponds
to the `base_learning_rate` keyword.
Returns:
schedule: A function that maps step counts to values.
"""
schedule = optax.constant_schedule(config["base_learning_rate"])
return schedule
[docs]def create_exp_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:
"""Create learning rate schedule to have an exponential decay.
Args:
config: Dictionary of configuration. The values to use correspond
to `base_learning_rate`, `num_epochs`, `steps_per_epochs` and
`lr_decay_rate`.
Returns:
schedule: A function that maps step counts to values.
"""
decay_steps = config["num_epochs"] * config["steps_per_epoch"]
schedule = optax.exponential_decay(
config["base_learning_rate"], decay_steps, config["lr_decay_rate"]
)
return schedule
[docs]def create_cosine_lr_schedule(config: ConfigDict) -> optax._src.base.Schedule:
"""Create learning rate to follow a pre-specified schedule.
Create learning rate to follow a pre-specified schedule with warmup
and cosine stages.
Args:
config: Dictionary of configuration. The parameters to use
correspond to keywords: `base_learning_rate`, `num_epochs`,
`warmup_epochs` and `steps_per_epoch`.
Returns:
schedule: A function that maps step counts to values.
"""
# Warmup stage
warmup_fn = optax.linear_schedule(
init_value=0.0,
end_value=config["base_learning_rate"],
transition_steps=config["warmup_epochs"] * config["steps_per_epoch"],
)
# Cosine stage
cosine_epochs = max(config["num_epochs"] - config["warmup_epochs"], 1)
cosine_fn = optax.cosine_decay_schedule(
init_value=config["base_learning_rate"],
decay_steps=cosine_epochs * config["steps_per_epoch"],
)
schedule = optax.join_schedules(
schedules=[warmup_fn, cosine_fn],
boundaries=[config["warmup_epochs"] * config["steps_per_epoch"]],
)
return schedule