# Source code for scico.flax.train.typed_dict

# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Definition of typed dictionaries for objects in training functionality."""

import sys
from typing import Any, Callable, List

if sys.version_info >= (3, 8):
    from typing import TypedDict  # pylint: disable=no-name-in-module
else:
    from typing_extensions import TypedDict

from scico.numpy import Array

# Alias used to annotate pytree-valued fields. It is simply ``typing.Any``,
# so it imposes no structural constraint on the annotated values.
PyTree = Any


class DataSetDict(TypedDict):
    """Typed dictionary describing a training data set.

    A data set is expected to supply two arrays: the input array under
    the ``image`` key and the output array under the ``label`` key.
    """

    image: Array  # input
    label: Array  # output
class ConfigDict(TypedDict):
    """Dictionary structure for training parameters.

    Definition of the dictionary structure expected for specifying
    training parameters. The dictionary is total, so every key listed
    below is declared as required.
    """

    # NOTE(review): ``seed`` is annotated as ``float``; confirm whether
    # consumers of this dictionary actually require an integer seed.
    seed: float
    opt_type: str
    momentum: float
    batch_size: int
    num_epochs: int
    base_learning_rate: float
    lr_decay_rate: float
    warmup_epochs: int
    steps_per_eval: int
    log_every_steps: int
    steps_per_epoch: int
    steps_per_checkpoint: int
    log: bool
    workdir: str
    checkpointing: bool
    return_state: bool
    # Callable-valued entries.
    lr_schedule: Callable
    criterion: Callable
    create_train_state: Callable
    train_step_fn: Callable
    eval_step_fn: Callable
    metrics_fn: Callable
    post_lst: List[Callable]  # list of callables
class ModelVarDict(TypedDict):
    """Typed dictionary grouping the variables of a Flax model.

    Both entries are annotated as ``PyTree`` (an alias of
    :data:`typing.Any`), so no structure is enforced on their contents.
    """

    params: PyTree
    batch_stats: PyTree
class MetricsDict(TypedDict, total=False):
    """Typed dictionary for quantities tracked during training.

    Holds metrics computed, or updates made, while training. The
    dictionary is non-total (``total=False``), so any subset of the
    keys below may be present.
    """

    loss: float
    snr: float
    learning_rate: float