Source code for scico.flax._flax
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Convolutional neural network models implemented in Flax."""
import warnings
from typing import Any, Optional
# Ignore FutureWarning messages. The filter is installed ahead of the
# imports that follow, so it is already active while they execute
# (presumably to silence warnings raised by those packages -- confirm).
warnings.simplefilter(action="ignore", category=FutureWarning)
from flax import serialization
from flax.linen.module import Module
from scico.numpy import Array, BlockArray
from scico.typing import Shape
# Annotation alias for the tree-like containers of model variables that
# are loaded, saved and applied in this module. It is an alias of `Any`,
# so it documents intent without constraining the type.
PyTree = Any
def load_weights(filename: str) -> PyTree:
    """Read the variables of a trained model from a file.

    The file is read in binary mode and decoded with
    :func:`flax.serialization.msgpack_restore`. Only the ``params`` and
    ``batch_stats`` entries of the decoded object are kept; both are
    looked up by key, so both must be present in the file.

    Args:
        filename: Name of file containing parameters for trained model.

    Returns:
        A dict with keys ``params`` and ``batch_stats`` holding the
        corresponding entries restored from the file.
    """
    with open(filename, "rb") as fid:
        raw = fid.read()

    # Decode after the file has been closed.
    restored = serialization.msgpack_restore(raw)

    return {name: restored[name] for name in ("params", "batch_stats")}
def save_weights(variables: PyTree, filename: str):
    """Write the variables of a trained model to a file.

    The variables are encoded with
    :func:`flax.serialization.msgpack_serialize` and the resulting bytes
    are written to `filename` in binary mode.

    Args:
        variables: Parameters of model to save.
        filename: Name of file to save parameters of trained model.
    """
    # Encode before opening the file, so that a failure to serialize
    # happens before the destination is opened for writing.
    payload = serialization.msgpack_serialize(variables)

    with open(filename, "wb") as fid:
        fid.write(payload)
class FlaxMap:
    r"""A trained flax model.

    Bundles a Flax model with its trained variables so that the pair can
    be applied to an input array via a plain call.
    """

    def __init__(self, model: Module, variables: PyTree):
        r"""Initialize a :class:`FlaxMap` object.

        Args:
            model: Flax model to apply.
            variables: Parameters and batch stats of trained model.
        """
        self.model = model
        self.variables = variables
        super().__init__()

    def __call__(self, x: Array) -> Array:
        r"""Apply trained flax model.

        Inputs with 2 or 3 dimensions are reshaped to 4 dimensions
        before the model is applied, and the singleton axes added in
        this way are removed from the output again. Inputs with any
        other number of dimensions are passed to the model unchanged
        and the model output is returned unchanged.

        Args:
            x: Input array.

        Returns:
            Output of flax model.

        Raises:
            NotImplementedError: If `x` is a :class:`BlockArray`.
        """
        if isinstance(x, BlockArray):
            raise NotImplementedError("FlaxMap does not support BlockArray input.")

        # Add singleton to input as necessary:
        #   scico typically works with (H x W) or (H x W x C) arrays
        #   flax expects (K x H x W x C) arrays
        #   H: spatial height  W: spatial width
        #   K: batch size  C: channel size
        xndim = x.ndim
        axsqueeze: Optional[Shape] = None
        if xndim == 2:
            # Add both a leading batch axis and a trailing channel axis.
            x = x.reshape((1,) + x.shape + (1,))
            axsqueeze = (0, 3)
        elif xndim == 3:
            # Add a leading batch axis only.
            x = x.reshape((1,) + x.shape)
            axsqueeze = (0,)

        y = self.model.apply(self.variables, x, train=False, mutable=False)

        # Remove only the singleton axes that were added above. Without
        # the None check, an input that was passed through unchanged
        # would reach squeeze(axis=None), which removes every singleton
        # axis of the output rather than leaving it as the model
        # produced it.
        if axsqueeze is not None and y.ndim != xndim:
            return y.squeeze(axis=axsqueeze)

        return y