scico.flax.train.apply

Functionality to evaluate a trained Flax model.

Uses data-parallel evaluation.

Functions

apply_fn(model, variables, batch)

Apply current model.

scico.flax.train.apply.apply_fn(model, variables, batch)

Apply current model.

Assumes sharded batched data and replicated variables for distributed processing.

This function is intended to be used via only_apply rather than called directly.

Parameters:
  • model (Any) – Flax model to apply.

  • variables (ModelVarDict) – State of model parameters (replicated).

  • batch (DataSetDict) – Sharded and batched training data.

Return type:
  Array

Returns:
  Output computed by the given model.
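
The "sharded batch, replicated variables" convention described above can be illustrated with a minimal sketch in plain JAX. This is not the scico implementation: toy_apply is a hypothetical stand-in for a Flax model's forward pass, the "image" key is an assumption about the layout of the batch dictionary, and the stacking and reshaping steps are just one simple way to add the leading device axis that jax.pmap expects.

```python
import jax
import jax.numpy as jnp


def toy_apply(variables, batch):
    # Hypothetical stand-in for a model forward pass (a single affine layer).
    # The "image" key is an assumption made for this sketch.
    return batch["image"] @ variables["w"] + variables["b"]


n_dev = jax.local_device_count()  # number of devices on this host
per_dev = 4                       # samples handled by each device

# Unreplicated variables: one copy of each parameter.
variables = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}

# Replicate: add a leading device axis holding one copy per device.
replicated = jax.tree_util.tree_map(
    lambda x: jnp.stack([x] * n_dev), variables
)

# Shard: split the batch axis into (device, per-device batch).
images = jnp.arange(n_dev * per_dev * 3, dtype=jnp.float32).reshape(
    n_dev * per_dev, 3
)
batch = {"image": images.reshape(n_dev, per_dev, 3)}

# Each device applies the function to its own shard with its own copy
# of the variables.
p_apply = jax.pmap(toy_apply, axis_name="batch")
output = p_apply(replicated, batch)  # shape (n_dev, per_dev, 2)

# Merge the device axis back into a single batch axis.
flat = output.reshape(n_dev * per_dev, 2)
```

The sketch runs on a single device as well (n_dev is then 1), in which case the leading axis has length one and the result matches an ordinary, unsharded evaluation.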