# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Configuration of Flax Train State."""

from typing import Any, Optional, Tuple, Union

import jax
import jax.numpy as jnp

import optax
from flax.training import train_state

from scico.numpy import Array
from scico.typing import Shape

from .typed_dict import ConfigDict, ModelVarDict

# Type aliases for readability; the Flax/JAX ecosystem has no single
# canonical type for these concepts.
ModuleDef = Any  # a Flax module (nn.Module); kept loose to avoid a hard flax.linen dependency here
KeyArray = Union[Array, jax.Array]  # PRNG key array
PyTree = Any  # arbitrary nested structure of arrays
ArrayTree = optax.Params  # optax's alias for a parameter pytree
class TrainState(train_state.TrainState):
    """Definition of Flax train state.

    Definition of Flax train state including `batch_stats` for batch
    normalization.
    """

    # Running batch-normalization statistics, carried alongside the
    # optimizer state so that `TrainState.create(..., batch_stats=...)`
    # and `state.replace(batch_stats=...)` work uniformly.
    batch_stats: Any
def initialize(key: KeyArray, model: ModuleDef, ishape: Shape) -> Tuple[PyTree, ...]:
    """Initialize Flax model.

    Args:
        key: A PRNGKey used as the random key.
        model: Flax model to train.
        ishape: Shape of signal (image) to process by `model`. Make sure
            that no batch dimension is included.

    Returns:
        Initial model parameters. If the model tracks batch statistics,
        a tuple ``(params, batch_stats)``; otherwise the ``params``
        variable collection alone.

        NOTE(review): the no-``batch_stats`` path returns the params
        dict itself, not a 1-tuple; callers that discriminate via
        ``len(...)`` (e.g. ``create_basic_train_state``) rely on the
        params dict having a single top-level key — confirm.
    """
    # NHWC layout with a singleton batch dimension prepended; channel
    # count is taken from the model definition.
    input_shape = (1, ishape[0], ishape[1], model.channels)

    # jit ensures initialization runs compiled (and its cost is paid once).
    @jax.jit
    def init(*args):
        return model.init(*args)

    variables = init({"params": key}, jnp.ones(input_shape, model.dtype))

    if "batch_stats" in variables:
        return variables["params"], variables["batch_stats"]

    return variables["params"]
def create_basic_train_state(
    key: KeyArray,
    config: ConfigDict,
    model: ModuleDef,
    ishape: Shape,
    learning_rate_fn: optax._src.base.Schedule,
    variables0: Optional[ModelVarDict] = None,
) -> TrainState:
    """Create Flax basic train state and initialize.

    Args:
        key: A PRNGKey used as the random key.
        config: Dictionary of configuration. The values to use
            correspond to keywords: `opt_type` and `momentum`.
        model: Flax model to train.
        ishape: Shape of signal (image) to process by `model`. Ensure
            that no batch dimension is included.
        learning_rate_fn: A function that maps step counts to values.
        variables0: Optional initial state of model parameters. If not
            provided a random initialization is performed.
            Default: ``None``.

    Returns:
        state: Flax train state which includes the model apply function,
            the model parameters and an Optax optimizer.

    Raises:
        NotImplementedError: If `config["opt_type"]` is not one of
            ``"SGD"``, ``"ADAM"`` or ``"ADAMW"``.
    """
    batch_stats = None
    if variables0 is None:
        # Random initialization; `initialize` returns a (params,
        # batch_stats) pair only when the model tracks batch statistics.
        aux = initialize(key, model, ishape)
        if len(aux) > 1:
            params, batch_stats = aux
        else:
            params = aux
    else:
        params = variables0["params"]
        if "batch_stats" in variables0:
            batch_stats = variables0["batch_stats"]

    if config["opt_type"] == "SGD":
        # Stochastic Gradient Descent optimiser
        if "momentum" in config:
            tx = optax.sgd(
                learning_rate=learning_rate_fn, momentum=config["momentum"], nesterov=True
            )
        else:
            tx = optax.sgd(learning_rate=learning_rate_fn)
    elif config["opt_type"] == "ADAM":
        # Adam optimiser
        tx = optax.adam(
            learning_rate=learning_rate_fn,
        )
    elif config["opt_type"] == "ADAMW":
        # Adam with weight decay regularization
        tx = optax.adamw(
            learning_rate=learning_rate_fn,
        )
    else:
        raise NotImplementedError(
            f"Optimizer specified {config['opt_type']} has not been included in SCICO."
        )

    # `batch_stats` is only passed when present so models without batch
    # normalization get a state with the field left at its default.
    if batch_stats is None:
        state = TrainState.create(
            apply_fn=model.apply,
            params=params,
            tx=tx,
        )
    else:
        state = TrainState.create(
            apply_fn=model.apply,
            params=params,
            tx=tx,
            batch_stats=batch_stats,
        )

    return state