scico.flax.train.trainer

Class providing integrated access to functionality for training Flax models.

Assumes sharded batched data and uses data parallel training.
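The exact layout this module expects is not spelled out here. The sketch below assumes the common JAX data-parallel convention: "sharded" means every array in the batch has a leading axis of size `jax.local_device_count()`, so that `jax.pmap` hands one slice to each device. The helper `shard_batch` and the keys `"image"` and `"label"` are hypothetical and only illustrate the reshaping.

```python
import jax
import jax.numpy as jnp
import numpy as np


def shard_batch(batch, n_devices):
    """Hypothetical helper: split the leading batch axis across devices.

    Assumes the global batch size is divisible by n_devices.
    """
    return jax.tree_util.tree_map(
        lambda x: x.reshape((n_devices, -1) + x.shape[1:]), batch
    )


n = jax.local_device_count()
# Global batch of 8 samples per device (hypothetical keys and shapes).
batch = {
    "image": np.zeros((8 * n, 32, 32, 1), np.float32),
    "label": np.zeros((8 * n, 32, 32, 1), np.float32),
}
sharded = shard_batch(batch, n)  # leaves now have shape (n, 8, 32, 32, 1)

# Each device computes on its own shard; the result has one entry per device.
per_device_mean = jax.pmap(lambda b: jnp.mean(b["image"]))(sharded)
```

Data-parallel training then keeps a replica of the model state on every device and averages gradients (and other per-device quantities) across replicas.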

Module Attributes

cross_replica_mean(x)

Average the input across all devices (replicas).

Functions

sync_batch_stats(state)

Sync the batch statistics across replicas.

scico.flax.train.trainer.sync_batch_stats(state)

Sync the batch statistics across replicas.

Return type:

TrainState
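A minimal sketch of what such a synchronization can look like, assuming that `batch_stats` is a pytree (for example, BatchNorm running averages) whose leaves carry a leading device axis, and that the averaging is done with a `pmap`-ed `lax.pmean`. `TrainState` below is a hypothetical frozen dataclass standing in for the real train state type, and `cross_replica_mean` is redefined locally so the block runs on its own; neither is this module's actual implementation.

```python
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
from jax import lax

# Local stand-in: average each leaf across the devices mapped over axis "x".
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x")


@dataclasses.dataclass(frozen=True)
class TrainState:  # hypothetical stand-in, not the library's TrainState
    step: int
    batch_stats: Any


def sync_batch_stats(state: TrainState) -> TrainState:
    # Each replica keeps its own running statistics; replace them with
    # their cross-replica mean so all replicas agree (e.g. before evaluation).
    return dataclasses.replace(
        state, batch_stats=cross_replica_mean(state.batch_stats)
    )


n = jax.local_device_count()
stats = {
    "mean": jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4),
    "var": jnp.ones((n, 4), dtype=jnp.float32),
}
state = TrainState(step=0, batch_stats=stats)
synced = sync_batch_stats(state)
```

Only the statistics are replaced; other fields such as `step` are left unchanged.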

scico.flax.train.trainer.cross_replica_mean(x)

Average the input across all devices (replicas).
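One common way to build a cross-device mean in JAX, and an assumption here rather than a statement of how this module defines it, is to `jax.pmap` a call to `jax.lax.pmean` over a named axis. The input needs a leading axis of size `jax.local_device_count()`; every device then receives the same averaged value.

```python
import jax
import jax.numpy as jnp
from jax import lax

# pmap over a named axis "x"; lax.pmean averages across all devices on that axis.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x")

n = jax.local_device_count()
x = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)  # one row per device
y = cross_replica_mean(x)  # every row of y equals the mean of the rows of x
```

On a single device the output equals the input, since the mean over one replica is the value itself.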