scico.flax.train.steps

Definitions of the steps iterated over during training or evaluation.

Functions

eval_step(state, batch, criterion, metrics_fn)

Evaluate current model state.

train_step(state, batch, learning_rate_fn, ...)

Perform a single data parallel training step.

train_step_post(state, batch, ...)

Perform a single data parallel training step with postprocessing.

scico.flax.train.steps.train_step(state, batch, learning_rate_fn, criterion, metrics_fn)

Perform a single data parallel training step.

Assumes sharded batched data. This function is intended to be used via BasicFlaxTrainer, not directly.
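In the usual `jax.pmap` convention, sharded batched data carries an extra leading axis whose length is the number of local devices, so each device receives one slice of the batch. The sketch below reshapes an unsharded batch that way. The helper `shard_batch` and the `"image"`/`"label"` keys are assumptions for illustration, not part of this module.

```python
import jax
import numpy as np


def shard_batch(batch):
    """Hypothetical helper: reshape each array from (B, ...) to (n_dev, B // n_dev, ...)."""
    n_dev = jax.local_device_count()

    def _shard(x):
        if x.shape[0] % n_dev != 0:
            raise ValueError(f"batch size {x.shape[0]} is not divisible by {n_dev} devices")
        return x.reshape((n_dev, x.shape[0] // n_dev) + x.shape[1:])

    # Apply the reshape to every array in the (possibly nested) batch dict.
    return jax.tree_util.tree_map(_shard, batch)


# An unsharded batch whose size is a multiple of the device count.
batch_size = 4 * jax.local_device_count()
batch = {
    "image": np.zeros((batch_size, 16, 16, 1), dtype=np.float32),
    "label": np.ones((batch_size, 16, 16, 1), dtype=np.float32),
}
sharded = shard_batch(batch)
print(sharded["image"].shape)  # (n_dev, 4, 16, 16, 1)
```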

Parameters:
  • state (TrainState) – Flax train state which includes the model apply function, the model parameters and an Optax optimizer.

  • batch (DataSetDict) – Sharded and batched training data.

  • learning_rate_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that maps step counts to learning rate values. It is used only for display purposes (Optax optimizers are stateless, so the current learning rate is not stored). The learning rate schedule actually applied is the one defined when the Flax state was created; if a different function is passed here, the displayed value will be inaccurate.

  • criterion (Callable) – A function that specifies the loss being minimized in training.

  • metrics_fn (Callable) – A function to evaluate the quality of the current model.

Return type:

Tuple[TrainState, MetricsDict]

Returns:

Updated train state (including the model parameters) and diagnostic statistics.
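The general shape of a data parallel training step can be sketched in plain JAX: compute the loss and gradients on each device's shard, average the gradients across devices with `lax.pmean`, apply the update, and attach `learning_rate_fn(step)` to the metrics. This is a minimal illustration, not scico's implementation: it replaces `TrainState` and the Optax optimizer with a parameter dict and plain gradient descent driven directly by the schedule, and the linear model, `mse`, `metrics_fn`, and `train_step_sketch` are all hypothetical.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def mse(output, label):
    """Criterion: mean squared error."""
    return jnp.mean((output - label) ** 2)


def metrics_fn(output, label, criterion):
    """Hypothetical metrics function: the loss averaged across devices."""
    return {"loss": lax.pmean(criterion(output, label), axis_name="batch")}


def train_step_sketch(params, step, batch, learning_rate_fn, criterion, metrics_fn):
    """One data parallel step for a hypothetical linear model y = x @ w."""

    def loss_fn(p):
        output = batch["image"] @ p["w"]
        return criterion(output, batch["label"]), output

    (_, output), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    # Average gradients over devices so every replica applies the same update.
    grads = lax.pmean(grads, axis_name="batch")
    lr = learning_rate_fn(step)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    metrics = metrics_fn(output, batch["label"], criterion)
    metrics["learning_rate"] = lr  # reported for display
    return new_params, step + 1, metrics


def lr_fn(step):
    return 0.1 / (1.0 + 0.01 * step)


n_dev = jax.local_device_count()
rng = np.random.RandomState(0)
w_true = np.array([[1.0], [-2.0], [0.5]], dtype=np.float32)
x = rng.randn(n_dev, 64, 3).astype(np.float32)
batch = {"image": x, "label": x @ w_true}  # already sharded: (n_dev, 64, ...)

# Replicate parameters and step counter across devices (leading axis = device).
params = {"w": jnp.zeros((n_dev, 3, 1))}
step = jnp.zeros((n_dev,), dtype=jnp.int32)

p_train_step = jax.pmap(
    functools.partial(train_step_sketch, learning_rate_fn=lr_fn, criterion=mse, metrics_fn=metrics_fn),
    axis_name="batch",
)
for _ in range(200):
    params, step, metrics = p_train_step(params, step, batch)
```

Binding the callables with `functools.partial` keeps them out of the mapped arguments, so `jax.pmap` only maps over the parameters, the step counter, and the data.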

scico.flax.train.steps.train_step_post(state, batch, learning_rate_fn, criterion, train_step_fn, metrics_fn, post_lst)

Perform a single data parallel training step with postprocessing.

A list of postprocessing functions (e.g., for spectral normalization or enforcing a positivity condition) is applied after the gradient update. Assumes sharded batched data.
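Postprocessing functions of this kind can be written to take the parameter pytree and return a modified pytree of the same structure. The two functions below are hypothetical examples: `positivity` enforces non-negativity by clipping, and `spectral_normalize` rescales each 2D weight matrix so that its spectral norm does not exceed a bound (a simple rescaling stand-in for spectral normalization, not necessarily the method used in scico). The nested `"dense"` parameter layout is also an assumption.

```python
import jax
import jax.numpy as jnp


def positivity(params):
    """Hypothetical post function: clip every parameter to be non-negative."""
    return jax.tree_util.tree_map(lambda p: jnp.maximum(p, 0.0), params)


def spectral_normalize(params, bound=1.0):
    """Hypothetical post function: rescale 2D weights so their spectral norm is <= bound."""

    def _rescale(p):
        if p.ndim != 2:
            return p  # leave biases and other non-matrix parameters untouched
        sigma = jnp.linalg.norm(p, ord=2)  # largest singular value
        return p * jnp.minimum(1.0, bound / sigma)

    return jax.tree_util.tree_map(_rescale, params)


params = {
    "dense": {
        "kernel": jnp.array([[3.0, 0.0], [0.0, -1.0]]),
        "bias": jnp.array([-0.5, 2.0]),
    }
}
post_lst = [spectral_normalize, positivity]
for post_fn in post_lst:
    params = post_fn(params)
```

The functions are applied in list order, and order can matter: clipping after the rescaling may change the spectral norm again, so a constraint enforced by an earlier function is not guaranteed to survive a later one.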

This function is intended to be used via BasicFlaxTrainer, not directly.

Parameters:
  • state (TrainState) – Flax train state which includes the model apply function, the model parameters and an Optax optimizer.

  • batch (DataSetDict) – Sharded and batched training data.

  • learning_rate_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that maps step counts to learning rate values.

  • criterion (Callable) – A function that specifies the loss being minimized in training.

  • train_step_fn (Callable) – A function that executes a training step.

  • metrics_fn (Callable) – A function to evaluate the quality of the current model.

  • post_lst (List[Callable]) – List of postprocessing functions to apply to the model parameters after the optimizer step (e.g., clipping to a specified range or normalizing).

Return type:

Tuple[TrainState, MetricsDict]

Returns:

Updated train state, with parameters satisfying the additional constraints, and diagnostic statistics.
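The composition described above, an ordinary training step followed by each function in the postprocessing list, can be sketched as follows. `State`, `toy_train_step`, `sq_error`, and `loss_metrics` are hypothetical stand-ins for a Flax `TrainState`, a real training step, and real loss and metrics functions; only the argument order of `train_step_post_sketch` mirrors the signature documented here.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class State(NamedTuple):
    """Hypothetical minimal stand-in for a Flax TrainState."""

    step: int
    params: Any


def train_step_post_sketch(state, batch, learning_rate_fn, criterion, train_step_fn, metrics_fn, post_lst):
    # Ordinary training step first ...
    new_state, metrics = train_step_fn(state, batch, learning_rate_fn, criterion, metrics_fn)
    # ... then each postprocessing function, in list order, on the updated parameters.
    for post_fn in post_lst:
        new_state = new_state._replace(params=post_fn(new_state.params))
    return new_state, metrics


def toy_train_step(state, batch, learning_rate_fn, criterion, metrics_fn):
    """Hypothetical step: one gradient descent update on criterion(params, label)."""
    metrics = metrics_fn(state.params, batch["label"], criterion)  # pre-update "forward pass"
    grads = jax.grad(lambda p: criterion(p, batch["label"]))(state.params)
    lr = learning_rate_fn(state.step)
    return State(step=state.step + 1, params=state.params - lr * grads), metrics


def sq_error(output, label):
    return jnp.sum((output - label) ** 2)


def loss_metrics(output, label, criterion):
    return {"loss": criterion(output, label)}


state = State(step=0, params=jnp.array([0.5, 0.5]))
batch = {"label": jnp.array([-1.0, 2.0])}
new_state, metrics = train_step_post_sketch(
    state,
    batch,
    learning_rate_fn=lambda step: 0.25,
    criterion=sq_error,
    train_step_fn=toy_train_step,
    metrics_fn=loss_metrics,
    post_lst=[lambda p: jnp.maximum(p, 0.0)],  # positivity constraint
)
```

In this sketch the metrics come from the step's forward pass, so they reflect the parameters before the update and before postprocessing.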

scico.flax.train.steps.eval_step(state, batch, criterion, metrics_fn)

Evaluate current model state.

Assumes sharded batched data. This function is intended to be used via BasicFlaxTrainer or only_evaluate, not directly.

Parameters:
  • state (TrainState) – Flax train state which includes the model apply function and the model parameters.

  • batch (DataSetDict) – Sharded and batched evaluation data.

  • criterion (Callable) – Loss function.

  • metrics_fn (Callable) – A function to evaluate the quality of the current model.

Return type:

MetricsDict

Returns:

Current diagnostic statistics.
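An evaluation step can follow the same data parallel pattern as training while performing only a forward pass: no gradients are computed and the state is not modified. The sketch below assumes a hypothetical linear model stored in a parameter dict and a hypothetical `metrics_fn` that averages the loss across devices with `lax.pmean`; it illustrates the pattern rather than reproducing scico's implementation.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax


def mse(output, label):
    return jnp.mean((output - label) ** 2)


def metrics_fn(output, label, criterion):
    """Hypothetical metrics: the loss, averaged over devices."""
    return lax.pmean({"loss": criterion(output, label)}, axis_name="batch")


def eval_step_sketch(params, batch, criterion, metrics_fn):
    # Forward pass only: no gradients, no parameter update.
    output = batch["image"] @ params["w"]
    return metrics_fn(output, batch["label"], criterion)


n_dev = jax.local_device_count()
w = np.array([[2.0]], dtype=np.float32)
x = np.ones((n_dev, 4, 1), dtype=np.float32)
batch = {"image": x, "label": 3.0 * x}  # the model predicts 2, the target is 3
params = {"w": jnp.broadcast_to(w, (n_dev, 1, 1))}  # replicated across devices

p_eval_step = jax.pmap(
    functools.partial(eval_step_sketch, criterion=mse, metrics_fn=metrics_fn),
    axis_name="batch",
)
metrics = p_eval_step(params, batch)  # one (identical) loss value per device
```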