# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Functionality to evaluate Flax trained model.
Uses data parallel evaluation.
"""

from typing import Any, Callable, Optional, Tuple

import jax
import jax.numpy as jnp

from flax import jax_utils

from scico.flax import create_input_iter
from scico.numpy import Array

from .checkpoints import checkpoint_restore
from .clu_utils import get_parameter_overview
from .learning_rate import create_cnst_lr_schedule
from .state import create_basic_train_state
from .typed_dict import ConfigDict, DataSetDict, ModelVarDict

ModuleDef = Any


def apply_fn(model: ModuleDef, variables: ModelVarDict, batch: DataSetDict) -> Array:
"""Apply current model.
Assumes sharded batched data and replicated variables for
distributed processing.
This function is intended to be used via
:meth:`~scico.flax.only_apply`, not directly.
Args:
model: Flax model to apply.
variables: State of model parameters (replicated).
batch: Sharded and batched training data.
Returns:
Output computed by given model.
"""
    output = model.apply(variables, batch["image"], train=False, mutable=False)
    return output


def only_apply(
    config: ConfigDict,
    model: ModuleDef,
    test_ds: DataSetDict,
    apply_fn: Callable = apply_fn,
    variables: Optional[ModelVarDict] = None,
) -> Tuple[Array, ModelVarDict]:
"""Execute model application loop.
Args:
config: Hyperparameter configuration.
model: Flax model to apply.
test_ds: Dictionary of testing data (includes images and
labels).
apply_fn: A hook for a function that applies current model.
Default: :meth:`~scico.flax.train.apply.apply_fn`, i.e. use
the standard apply function.
variables: Model parameters to use for evaluation. Default:
``None`` (i.e. read from checkpoint).
Returns:
Output of model evaluated at the input provided in `test_ds`.
Raises:
RuntimeError: If no model variables and no checkpoint are
specified.
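
    Example:
        A minimal usage sketch (assumed objects: a Flax ``model`` with
        trained ``variables`` and a ``test_ds`` dictionary built
        elsewhere; only the config keys read by this function are
        shown)::

            config = {"seed": 0, "batch_size": 8}
            output, variables = only_apply(config, model, test_ds, variables=variables)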
"""
if "workdir" in config:
workdir: str = config["workdir"]
else:
workdir = "./"
if "checkpointing" in config:
checkpointing: bool = config["checkpointing"]
else:
checkpointing = False
# Configure seed.
key = jax.random.PRNGKey(config["seed"])
    if variables is None:
        if checkpointing:  # pragma: no cover
            # Build a skeleton train state with correctly shaped
            # parameters, then restore the checkpointed values into it.
            ishape = test_ds["image"].shape[1:3]
            lr_ = create_cnst_lr_schedule(config)
            empty_state = create_basic_train_state(key, config, model, ishape, lr_)
            state = checkpoint_restore(empty_state, workdir)
            if hasattr(state, "batch_stats"):
                variables = {
                    "params": state.params,
                    "batch_stats": state.batch_stats,
                }  # type: ignore
                print(get_parameter_overview(variables["params"]))
                print(get_parameter_overview(variables["batch_stats"]))
            else:
                variables = {"params": state.params, "batch_stats": {}}
                print(get_parameter_overview(variables["params"]))
        else:
            raise RuntimeError("No variables or checkpoint provided.")

    # For distributed testing
    local_batch_size = config["batch_size"] // jax.process_count()
    size_device_prefetch = 2  # Set for GPU

    # Set data iterator
    eval_dt_iter = create_input_iter(
        key,  # eval: no permutation
        test_ds,
        local_batch_size,
        size_device_prefetch,
        model.dtype,
        train=False,
    )
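    # The iterator is expected to yield device-sharded batches of shape
    # (jax.local_device_count(), per_device_batch, ...), as required by
    # the pmapped apply step below.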

    p_apply_step = jax.pmap(apply_fn, axis_name="batch", static_broadcasted_argnums=0)
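    # `static_broadcasted_argnums=0` treats the model (a hashable Flax
    # module) as a compile-time constant broadcast to all devices, while
    # `axis_name="batch"` labels the mapped axis for any cross-device
    # collectives.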

    # Evaluate model with provided variables
    variables = jax_utils.replicate(variables)
    num_examples = test_ds["image"].shape[0]
    steps_ = num_examples // config["batch_size"]

    output_lst = []
    for _ in range(steps_):
        eval_batch = next(eval_dt_iter)
        output_batch = p_apply_step(model, variables, eval_batch)
        output_lst.append(output_batch.reshape((-1,) + output_batch.shape[-3:]))
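    # Each per-step output of shape (devices, per_device_batch, H, W, C)
    # is flattened to (devices * per_device_batch, H, W, C) before being
    # accumulated.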

    # Allow for completing the async run
    jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()

    # Extract one copy of variables
    variables = jax_utils.unreplicate(variables)

    # Convert to array
    output = jnp.array(output_lst)
    # Merge the per-step leading dimension into the batch dimension
    output = output.reshape((-1,) + output.shape[-3:])

    return output, variables  # type: ignore