# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Convolutional neural network models implemented in Flax."""

import warnings
from typing import Any, Optional

# Ignore FutureWarning messages before flax is imported. Note that
# simplefilter modifies the global warnings filter list, so this affects
# the whole process, not only this module.
warnings.simplefilter(action="ignore", category=FutureWarning)

from flax import serialization
from flax.linen.module import Module

from scico.numpy import Array, BlockArray
from scico.typing import Shape

PyTree = Any


def load_variables(filename: str) -> PyTree:
    """Load trained model variables.

    The file is expected to contain a msgpack-serialized variable tree,
    e.g. as written by :func:`save_variables`. Only the ``"params"`` and
    (when present) ``"batch_stats"`` collections are returned.

    Args:
        filename: Name of file containing trained model variables.

    Returns:
        A tree-like structure containing the values of the model variables.

    Raises:
        KeyError: If the restored variables have no ``"params"`` entry.
    """
    with open(filename, "rb") as data_file:
        bytes_input = data_file.read()
    variables = serialization.msgpack_restore(bytes_input)
    var_in = {"params": variables["params"]}
    # Models without batch normalization have no "batch_stats" collection;
    # include it only when it was saved instead of raising KeyError.
    if "batch_stats" in variables:
        var_in["batch_stats"] = variables["batch_stats"]
    return var_in


def save_variables(variables: PyTree, filename: str) -> None:
    """Save trained model variables.

    The variables are serialized with msgpack so that they can be
    restored with :func:`load_variables`.

    Args:
        variables: Model variables to save.
        filename: Name of file to which model variables should be saved.
    """
    bytes_output = serialization.msgpack_serialize(variables)
    with open(filename, "wb") as data_file:
        data_file.write(bytes_output)


class FlaxMap:
    r"""A trained flax model.

    Wraps a flax model together with its trained variables so that it can
    be applied as a function of a single array argument.
    """

    def __init__(self, model: Module, variables: PyTree):
        r"""Initialize a :class:`FlaxMap` object.

        Args:
            model: Flax model to apply.
            variables: Parameters and batch stats of trained model.
        """
        self.model = model
        self.variables = variables
        super().__init__()

    def __call__(self, x: Array) -> Array:
        r"""Apply trained flax model.

        A 2D input (H x W) is expanded to (1 x H x W x 1) and a 3D input
        (H x W x C) to (1 x H x W x C) before the model is applied. The
        axes added in this way are removed from the model output again,
        provided they are still singleton there. Inputs of any other rank
        are passed to the model, and its output returned, unchanged.

        Args:
            x: Input array.

        Returns:
            Output of flax model.

        Raises:
            NotImplementedError: If `x` is a :class:`BlockArray`.
        """
        if isinstance(x, BlockArray):
            raise NotImplementedError("FlaxMap does not support BlockArray inputs.")

        # Add singleton to input as necessary:
        #   scico typically works with (H x W) or (H x W x C) arrays
        #   flax expects (K x H x W x C) arrays
        #   H: spatial height  W: spatial width
        #   K: batch size  C: channel size
        xndim = x.ndim
        axsqueeze: Optional[Shape] = None
        if xndim == 2:
            x = x.reshape((1,) + x.shape + (1,))
            axsqueeze = (0, 3)
        elif xndim == 3:
            x = x.reshape((1,) + x.shape)
            axsqueeze = (0,)

        y = self.model.apply(self.variables, x, train=False, mutable=False)

        # Only undo axes that were added above. When no axes were added
        # (axsqueeze is None), squeezing with axis=None would remove every
        # singleton axis of the output, including ones produced by the model.
        if axsqueeze is not None and y.ndim != xndim:
            # Skip added axes that are not singleton in the output (e.g. a
            # single-channel 2D input mapped to a multi-channel output keeps
            # its channel axis) instead of letting squeeze raise an error.
            axsqueeze = tuple(ax for ax in axsqueeze if ax < y.ndim and y.shape[ax] == 1)
            return y.squeeze(axis=axsqueeze)
        return y