scico.flax.train.state

Configuration of Flax Train State.

Functions

create_basic_train_state(key, config, model, ...)

Create and initialize a basic Flax train state.

initialize(key, model, ishape)

Initialize Flax model.

Classes

TrainState(step, apply_fn, params, tx, ...)

Definition of Flax train state.

class scico.flax.train.state.TrainState(step, apply_fn, params, tx, opt_state, batch_stats)

Bases: flax.training.train_state.TrainState

Definition of Flax train state.

Definition of Flax train state including batch_stats for batch normalization.

replace(**updates)

Returns a new object replacing the specified fields with new values.

scico.flax.train.state.initialize(key, model, ishape)

Initialize Flax model.

Parameters:
  • key (Array) – JAX PRNG key used for the random initialization of the model parameters.

  • model (Any) – Flax model to train.

  • ishape (Tuple[int, ...]) – Shape of the signal (image) to be processed by the model. It must not include a batch dimension.

Return type:

Tuple[Any, ...]

Returns:

Initial model parameters (including batch_stats).

scico.flax.train.state.create_basic_train_state(key, config, model, ishape, learning_rate_fn, variables0=None)

Create and initialize a basic Flax train state.

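Parameters:
  • key (Array) – JAX PRNG key used for the random initialization of the model parameters.

  • config – Dictionary of training configuration options.

  • model (Any) – Flax model to train.

  • ishape (Tuple[int, ...]) – Shape of the signal (image) to be processed by the model. It must not include a batch dimension.

  • learning_rate_fn – Function that maps the training step count to a learning rate.

  • variables0 – Optional initial model variables. If None (the default), the model is initialized using key.
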
Return type:

TrainState

Returns:

state – Flax train state which includes the model apply function, the model parameters, and an Optax optimizer.