scico.flax.train.trainer
Class providing integrated access to functionality for training Flax models.
Assumes sharded, batched data and uses data-parallel training.
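"Sharded batched data" and "data-parallel training" fit together as follows. Each global batch is split along its leading axis into one shard per device. Every device computes gradients on its own shard, and the gradients are averaged across devices so that all parameter replicas receive the same update. The sketch below is illustrative only and is not the scico implementation. It assumes a toy linear least-squares model, and `shard_batch`, `loss_fn`, `train_step`, and `LR` are hypothetical names. It runs on any number of local devices, including a single CPU.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # learning rate (illustrative value)


def shard_batch(batch):
    """Hypothetical helper: reshape each array's leading (batch) axis to
    (n_devices, per_device_batch, ...). The batch size must be divisible
    by the number of local devices."""
    n_dev = jax.local_device_count()
    return jax.tree_util.tree_map(
        lambda a: a.reshape((n_dev, -1) + a.shape[1:]), batch
    )


def loss_fn(w, x, y):
    # Mean squared error of a linear model on one device's shard.
    return jnp.mean((x @ w - y) ** 2)


def train_step(w, batch):
    # Gradient on the local shard only ...
    grads = jax.grad(loss_fn)(w, batch["x"], batch["y"])
    # ... then averaged across devices, so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return w - LR * grads


# Map train_step over the leading (device) axis of both arguments.
p_train_step = jax.pmap(train_step, axis_name="batch")

n_dev = jax.local_device_count()
x = jax.random.normal(jax.random.PRNGKey(0), (32 * n_dev, 3))
w_true = jnp.array([1.0, -2.0, 0.5])
y = x @ w_true
batch = shard_batch({"x": x, "y": y})

# Replicate the parameters: one identical copy per device.
w_rep = jnp.broadcast_to(jnp.zeros(3), (n_dev, 3))
for _ in range(200):
    w_rep = p_train_step(w_rep, batch)
```

All shards have the same size, so the mean of the per-shard mean gradients equals the full-batch mean gradient. One data-parallel step therefore matches one single-device full-batch gradient step.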
Module Attributes
cross_replica_mean | Average the inputs across all devices. |
Functions
cross_replica_mean(x) | Average the inputs across all devices. |
sync_batch_stats(state) | Sync the batch statistics across replicas. |
- scico.flax.train.trainer.sync_batch_stats(state)
Sync the batch statistics across replicas.
- Return type:
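The sketch below shows what syncing batch statistics across replicas amounts to. It assumes the training state carries a `batch_stats` pytree whose leaves have a leading device axis, as when a state is replicated for `jax.pmap`. `ToyState` is a hypothetical stand-in for the real state class. `cross_replica_mean` is built locally from `jax.pmap` and `jax.lax.pmean`, which is an assumption, not the scico source. Every leaf of `batch_stats` is replaced by its average over devices, and all other fields are left untouched.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp

# Assumed construction: average over the pmapped axis named "x".
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), axis_name="x")


class ToyState(NamedTuple):
    """Hypothetical stand-in for a replicated training state."""
    step: Any         # replicated step counter, one entry per device
    batch_stats: Any  # pytree of running statistics, leading axis = device


def sync_batch_stats(state: ToyState) -> ToyState:
    # Average every leaf of batch_stats across devices; other fields unchanged.
    return state._replace(batch_stats=cross_replica_mean(state.batch_stats))


n_dev = jax.local_device_count()
state = ToyState(
    step=jnp.zeros((n_dev,), dtype=jnp.int32),
    batch_stats={
        "BatchNorm_0": {
            # Device i holds running means [3i, 3i + 1, 3i + 2].
            "mean": jnp.arange(3 * n_dev, dtype=jnp.float32).reshape(n_dev, 3),
            "var": jnp.ones((n_dev, 3)),
        }
    },
)
synced = sync_batch_stats(state)
```

After the call, every device row of `synced.batch_stats["BatchNorm_0"]["mean"]` equals the column-wise mean over devices. Each replica therefore evaluates with identical statistics.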
- scico.flax.train.trainer.cross_replica_mean(x)
Average the inputs across all devices.
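One way to obtain a function with this behavior is to `jax.pmap` a call to `jax.lax.pmean` over a named axis. The minimal sketch below assumes that construction and is not taken from the scico source. It runs on however many local devices are available, including a single CPU.

```python
import jax
import jax.numpy as jnp

# Assumed construction (not taken from the scico source): pmap a function that
# averages its argument over the mapped axis named "x".
cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, "x"), axis_name="x")

n_dev = jax.local_device_count()
# The leading axis indexes devices; device i holds the value i.
x = jnp.arange(n_dev, dtype=jnp.float32).reshape(n_dev, 1)
y = cross_replica_mean(x)
# Every device now holds the same value, mean(0, ..., n_dev - 1).
```

With this construction, the input's leading axis indexes devices. Its size can be at most the number of available devices, which is why inputs arrive already sharded with a leading device axis.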