scico.flax.train.state

Configuration of Flax Train State.

Functions

create_basic_train_state(key, config, model, ...)

Create Flax basic train state and initialize.

initialize(key, model, ishape)

Initialize Flax model.

Classes

TrainState(step, apply_fn, params, tx, ...)

Definition of Flax train state.

class scico.flax.train.state.TrainState(step, apply_fn, params, tx, opt_state, batch_stats)

Bases: TrainState

Definition of Flax train state.

Definition of Flax train state including batch_stats for batch normalization.

replace(**updates)

Returns a new object replacing the specified fields with new values.
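The behavior described for replace can be sketched with a frozen standard-library dataclass. This is only an illustration of the semantics: ToyState is a hypothetical stand-in, not the scico or Flax class, and its fields are a reduced subset chosen for the example.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for a train state (not the scico class)."""

    step: int
    params: Any
    batch_stats: Any

    def replace(self, **updates):
        # Build a new object with the given fields swapped in;
        # the original object is left unchanged.
        return dataclasses.replace(self, **updates)


state = ToyState(step=0, params={"w": 1.0}, batch_stats={"mean": 0.0})

# Only step and params are updated; batch_stats is carried over as-is.
new_state = state.replace(step=state.step + 1, params={"w": 0.9})
```

Because the object is immutable, an update is expressed as a rebinding (state = state.replace(...)) rather than an in-place assignment.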

scico.flax.train.state.initialize(key, model, ishape)

Initialize Flax model.

Parameters:
  • key (Array) – A PRNGKey used as the random key.

  • model (Any) – Flax model to train.

  • ishape (Tuple[int, ...]) – Shape of the signal (image) to be processed by the model. Do not include a batch dimension.

Return type:

Tuple[Any, ...]

Returns:

Initial model parameters (including batch_stats).

scico.flax.train.state.create_basic_train_state(key, config, model, ishape, learning_rate_fn, variables0=None)

Create Flax basic train state and initialize.

Parameters:
  • key (Array) – A PRNGKey used as the random key.

  • config (ConfigDict) – Configuration dictionary. The keys used are opt_type and momentum.

  • model (Any) – Flax model to train.

  • ishape (Tuple[int, ...]) – Shape of the signal (image) to be processed by the model. Do not include a batch dimension.

  • learning_rate_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that maps step counts to learning rate values.

  • variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. If not provided, a random initialization is performed. Default: None.

Return type:

TrainState

Returns:

state – Flax train state, which includes the model apply function, the model parameters, and an Optax optimizer.
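The learning_rate_fn parameter above is described only as a function that maps step counts to values. A minimal sketch of such a function, assuming an exponential decay schedule, is given below. The helper make_exponential_decay and its arguments are hypothetical names introduced for this illustration and are not part of scico.

```python
def make_exponential_decay(init_value, decay_rate, transition_steps):
    """Hypothetical helper: build a schedule mapping a step count to a rate."""

    def schedule(step):
        # The rate is multiplied by decay_rate once every transition_steps steps.
        return init_value * decay_rate ** (step / transition_steps)

    return schedule


# Assumed example values: start at 0.1 and halve every 100 steps.
learning_rate_fn = make_exponential_decay(
    init_value=0.1, decay_rate=0.5, transition_steps=100
)

# Evaluate the schedule at a few step counts.
rates = [learning_rate_fn(step) for step in (0, 100, 200)]
```

Any callable with this one-argument form fits the documented type; whether a given schedule suits a particular optimizer configuration is a separate question not covered here.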