# -*- coding: utf-8 -*-
# Copyright (C) 2022-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Typed dictionary definitions for objects used in training functionality."""

import sys
from typing import Any, Callable, List

# ``TypedDict`` is taken from the ``typing_extensions`` backport on
# interpreters older than 3.8, and from the standard ``typing`` module
# otherwise.
if sys.version_info < (3, 8):
    from typing_extensions import TypedDict
else:
    from typing import TypedDict  # pylint: disable=no-name-in-module

from scico.numpy import Array

# Loose alias used to annotate pytree-valued entries. Being ``Any``, it
# imposes no structural constraint on the annotated value.
PyTree = Any
class DataSetDict(TypedDict):
    """Dictionary structure for training data sets.

    Definition of the dictionary structure expected for the training
    data sets. The class is declared with the default ``total=True``,
    so both keys are nominally required.
    """

    image: Array  # input
    label: Array  # output
class ConfigDict(TypedDict):
    """Typed dictionary holding training parameters.

    Describes the keys, and the expected type of each value, of the
    dictionary used to specify training parameters. Declared with the
    default ``total=True``, so every key listed here is nominally
    required.
    """

    # Scalar settings (numeric values plus the optimizer type name).
    seed: float
    opt_type: str
    momentum: float
    batch_size: int
    num_epochs: int
    base_learning_rate: float
    lr_decay_rate: float
    warmup_epochs: int
    steps_per_eval: int
    log_every_steps: int
    steps_per_epoch: int
    steps_per_checkpoint: int

    # Boolean switches and the working directory path.
    log: bool
    workdir: str
    checkpointing: bool
    return_state: bool

    # Callable-valued entries.
    lr_schedule: Callable
    criterion: Callable
    create_train_state: Callable
    train_step_fn: Callable
    eval_step_fn: Callable
    metrics_fn: Callable

    # A list of callables.
    post_lst: List[Callable]
class ModelVarDict(TypedDict):
    """Dictionary structure for Flax variables.

    Definition of the dictionary structure grouping all Flax model
    variables.
    """

    # Both entries are annotated with ``PyTree``, which is an alias for
    # ``Any``, so no particular structure is enforced for either value.
    params: PyTree
    batch_stats: PyTree
class MetricsDict(TypedDict, total=False):
    """Dictionary structure for training metrics.

    Definition of the dictionary structure for metrics computed or
    updates made during training. Declared with ``total=False``, so any
    subset of the keys below (including none) may be present.
    """

    loss: float
    snr: float
    learning_rate: float