scico.flax.train.state
Configuration of Flax Train State.
Functions
create_basic_train_state | Create Flax basic train state and initialize.
initialize | Initialize Flax model.
Classes
TrainState | Definition of Flax train state.
- class scico.flax.train.state.TrainState(step, apply_fn, params, tx, opt_state, batch_stats)
Bases: flax.training.train_state.TrainState
Definition of Flax train state, extended with a batch_stats field that holds the batch normalization statistics.
- replace(**updates)
Return a new object with the specified fields replaced by new values.
- scico.flax.train.state.initialize(key, model, ishape)
Initialize Flax model.
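- Parameters:
  - key (Array) – A PRNGKey used as the random key.
  - model (Any) – Flax model to initialize.
  - ishape (Tuple[int, ...]) – Shape of signal (image) to process by model. Ensure that no batch dimension is included.
- Returns:
  Initial model parameters (including batch_stats).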
- scico.flax.train.state.create_basic_train_state(key, config, model, ishape, learning_rate_fn, variables0=None)
Create Flax basic train state and initialize.
- Parameters:
  - key (Array) – A PRNGKey used as the random key.
  - config (ConfigDict) – Configuration dictionary. The values used are those stored under the keys opt_type and momentum.
  - model (Any) – Flax model to train.
  - ishape (Tuple[int, ...]) – Shape of signal (image) to process by model. Ensure that no batch dimension is included.
  - variables0 (Optional[ModelVarDict]) – Optional initial state of model parameters. If not provided, a random initialization is performed. Default: None.
  - learning_rate_fn (Callable[[Union[Array, ndarray, bool_, number, float, int]], Union[Array, ndarray, bool_, number, float, int]]) – A function that maps step counts to learning-rate values.
- Return type:
  TrainState
- Returns:
  state – Flax train state, which includes the model apply function, the model parameters and an Optax optimizer.