# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.
"""Class providing integrated access to functionality for training Flax
models.

Assumes sharded batched data and uses data parallel training.
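
A typical usage sketch (illustrative only; assumes a populated configuration
dictionary, a Flax model, and training and testing data dictionaries, here
named ``config``, ``model``, ``train_ds``, and ``test_ds``)::

    trainer = BasicFlaxTrainer(config, model, train_ds, test_ds)
    model_vars, stats = trainer.train()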
"""
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
import functools
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import jax
from jax import lax
from flax import jax_utils
from flax.training import common_utils
from scico.diagnostics import IterationStats
from scico.numpy import Array
from .checkpoints import checkpoint_restore, checkpoint_save
from .clu_utils import get_parameter_overview
from .diagnostics import ArgumentStruct, compute_metrics, stats_obj
from .input_pipeline import create_input_iter
from .learning_rate import create_cnst_lr_schedule
from .losses import mse_loss
from .state import TrainState, create_basic_train_state
from .steps import eval_step, train_step, train_step_post
from .typed_dict import ConfigDict, DataSetDict, MetricsDict, ModelVarDict
ModuleDef = Any
KeyArray = Union[Array, jax.Array]
PyTree = Any
DType = Any
# sync across replicas
def sync_batch_stats(state: TrainState) -> TrainState:
"""Sync the batch statistics across replicas."""
# Each device has its own version of the running average batch
# statistics, and those are synced before evaluation.
return state.replace(batch_stats=cross_replica_mean(state.batch_stats))
# pmean only works inside pmap because it needs an axis name.
#: This function will average the inputs across all devices.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, "x"), "x")
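# Illustrative sketch (comment only, not executed): for an input with a
# leading device axis, e.g.
#     x = jax.numpy.arange(jax.local_device_count(), dtype=float)
#     y = cross_replica_mean(x)
# y has the same shape as x, with every per-device entry replaced by the
# mean over all devices. Applied to replicated batch_stats, this averages
# the running statistics elementwise across replicas.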
class BasicFlaxTrainer:
"""Class for encapsulating Flax training configuration and execution."""
def __init__(
self,
config: ConfigDict,
model: ModuleDef,
train_ds: DataSetDict,
test_ds: DataSetDict,
variables0: Optional[ModelVarDict] = None,
):
"""Initializer for :class:`BasicFlaxTrainer`.
Initializer for :class:`BasicFlaxTrainer` to configure model
training and evaluation loop. Construct a Flax train state (which
includes the model apply function, the model parameters and an
Optax optimizer). This uses data parallel training assuming
sharded batched data.
Args:
config: Hyperparameter configuration.
model: Flax model to train.
train_ds: Dictionary of training data (includes images and
labels).
test_ds: Dictionary of testing data (includes images and
labels).
variables0: Optional initial state of model parameters.
Default: ``None``.
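
Example (an illustrative sketch only; the configuration keys shown are
typical but not exhaustive, default values are used for omitted
parameters, and ``model``, ``train_ds``, and ``test_ds`` are assumed to
be constructed as described above)::

    config = {
        "seed": 0,
        "batch_size": 16,
        "num_epochs": 50,
        "base_learning_rate": 1e-3,
        "log_every_steps": 100,
        "log": True,
        "checkpointing": False,
    }
    trainer = BasicFlaxTrainer(config, model, train_ds, test_ds)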
"""
# Configure seed
if "seed" not in config:
key = jax.random.PRNGKey(0)
else:
key = jax.random.PRNGKey(config["seed"])
# Split seed for data iterators and model initialization
key1, key2 = jax.random.split(key)
# Object for storing iteration stats
self.itstat_object: Optional[IterationStats] = None
# Configure training loop
len_train = train_ds["image"].shape[0]
len_test = test_ds["image"].shape[0]
self.set_training_parameters(config, len_train, len_test)
self.construct_data_iterators(train_ds, test_ds, key1, model.dtype)
self.define_parallel_training_functions()
self.initialize_training_state(config, key2, model, variables0)
# Store configuration
self.config = config
def set_training_parameters(
self,
config: ConfigDict,
len_train: int,
len_test: int,
):
"""Extract configuration parameters and construct training functions.
Parameters and functions are passed in the configuration
dictionary. Default values are used when parameters are not
included in configuration.
Args:
config: Hyperparameter configuration.
len_train: Number of samples in training set.
len_test: Number of samples in testing set.
"""
self.configure_steps(config, len_train, len_test)
self.configure_reporting(config)
self.configure_training_functions(config)
def construct_data_iterators(
self,
train_ds: DataSetDict,
test_ds: DataSetDict,
key: KeyArray,
mdtype: DType,
):
"""Construct iterators for training and testing (evaluation) sets.
Args:
train_ds: Dictionary of training data (includes images
and labels).
test_ds: Dictionary of testing data (includes images
and labels).
key: A PRNGKey used as the random key.
mdtype: Output type of Flax model to be trained.
"""
size_device_prefetch = 2 # Set for GPU
self.train_dt_iter = create_input_iter(
key,
train_ds,
self.local_batch_size,
size_device_prefetch,
mdtype,
train=True,
)
self.eval_dt_iter = create_input_iter(
key, # eval: no permutation
test_ds,
self.local_batch_size,
size_device_prefetch,
mdtype,
train=False,
)
self.ishape = train_ds["image"].shape[1:3]
self.log(
"Channels: %d, training signals: %d, testing"
" signals: %d, signal size: %d"
% (
train_ds["label"].shape[-1],
train_ds["label"].shape[0],
test_ds["label"].shape[0],
train_ds["label"].shape[1],
)
)
def define_parallel_training_functions(self):
"""Construct parallel versions of training functions.
Construct parallel versions of training functions via
:func:`jax.pmap`.
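
Conceptual sketch (for illustration only): the resulting
``p_train_step`` and ``p_eval_step`` expect a replicated train state
and batches with a leading device axis, as produced by the input
iterators, e.g.::

    # batch["image"] has shape (num_devices, per_device_batch, ...)
    state, metrics = self.p_train_step(state, batch)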
"""
if self.post_lst is not None:
self.p_train_step = jax.pmap(
functools.partial(
train_step_post,
train_step_fn=self.train_step_fn,
learning_rate_fn=self.lr_schedule,
criterion=self.criterion,
metrics_fn=self.metrics_fn,
post_lst=self.post_lst,
),
axis_name="batch",
)
else:
self.p_train_step = jax.pmap(
functools.partial(
self.train_step_fn,
learning_rate_fn=self.lr_schedule,
criterion=self.criterion,
metrics_fn=self.metrics_fn,
),
axis_name="batch",
)
self.p_eval_step = jax.pmap(
functools.partial(
self.eval_step_fn, criterion=self.criterion, metrics_fn=self.metrics_fn
),
axis_name="batch",
)
def initialize_training_state(
self,
config: ConfigDict,
key: KeyArray,
model: ModuleDef,
variables0: Optional[ModelVarDict] = None,
):
"""Construct and initialize Flax train state.
A train state object helps to keep optimizer and module
functionality grouped for training.
Args:
config: Hyperparameter configuration.
key: A PRNGKey used as the random key.
model: Flax model to train.
variables0: Optional initial state of model parameters.
Default: ``None``.
"""
# Create Flax training state
state = self.create_train_state(
key, config, model, self.ishape, self.lr_schedule, variables0
)
# Only restore if no initialization is provided
if self.checkpointing and variables0 is None:
ok_no_ckpt = True # It is ok if no checkpoint is found
state = checkpoint_restore(state, self.workdir, ok_no_ckpt)
self.log(get_parameter_overview(state.params))
if hasattr(state, "batch_stats"):
self.log(get_parameter_overview(state.batch_stats))
self.state = state
def train(self) -> Tuple[Dict[str, Any], Optional[IterationStats]]:
"""Execute training loop.
Returns:
Model variables extracted from TrainState and iteration
stats object obtained after executing the training loop.
Alternatively the TrainState can be returned directly instead
of the model variables. Note that the iteration stats object
is not ``None`` only if log is enabled when configuring the
training loop.
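
Example (illustrative; ``trainer`` is an instance of this class)::

    model_vars, stats = trainer.train()
    params = model_vars["params"]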
"""
state = self.state
step_offset = int(state.step) # > 0 if restarting from checkpoint
# For parallel training
state = jax_utils.replicate(state)
# Execute training loop and register stats
t0 = time.time()
self.log("Initial compilation, this might take some minutes...")
train_metrics: List[Any] = []
for step, batch in zip(range(step_offset, self.num_steps), self.train_dt_iter):
state, metrics = self.p_train_step(state, batch)
# Training metrics computed in step
train_metrics.append(metrics)
if step == step_offset:
self.log("Initial compilation completed.")
if (step + 1) % self.log_every_steps == 0:
# sync batch statistics across replicas
state = sync_batch_stats(state)
self.update_metrics(state, step, train_metrics, t0)
train_metrics = []
if (step + 1) % self.steps_per_checkpoint == 0 or step + 1 == self.num_steps:
# sync batch statistics across replicas
state = sync_batch_stats(state)
self.checkpoint(state)
# Wait for asynchronous execution to finish
jax.random.normal(jax.random.PRNGKey(0), ()).block_until_ready()
# Close object for iteration stats if logging
if self.logflag:
assert self.itstat_object is not None
self.itstat_object.end()
state = sync_batch_stats(state)
# Final checkpointing
self.checkpoint(state)
# Extract one copy of state
state = jax_utils.unreplicate(state)
if self.return_state:
return state, self.itstat_object # type: ignore
dvar: ModelVarDict = {
"params": state.params,
"batch_stats": state.batch_stats,
}
return dvar, self.itstat_object # type: ignore
def update_metrics(self, state: TrainState, step: int, train_metrics: List[MetricsDict], t0):
"""Compute metrics for current model state.
Metrics for training and testing (eval) sets are computed and
stored in an iteration stats object. This is executed only if
logging is enabled.
Args:
state: Flax train state which includes the model apply
function and the model parameters.
step: Current step in training.
train_metrics: List of diagnostic statistics computed from
training set.
t0: Time when training loop started.
"""
if not self.logflag:
return
eval_metrics: List[Any] = []
# Build summary dictionary for logging
# Include training stats
train_metrics = common_utils.get_metrics(train_metrics)
summary = {
f"train_{k}": v
for k, v in jax.tree_util.tree_map(lambda x: x.mean(), train_metrics).items()
}
epoch = step // self.steps_per_epoch
summary["epoch"] = epoch
summary["time"] = time.time() - t0
# Eval over testing set
for _ in range(self.steps_per_eval):
eval_batch = next(self.eval_dt_iter)
metrics = self.p_eval_step(state, eval_batch)
eval_metrics.append(metrics)
# Compute testing metrics
eval_metrics = common_utils.get_metrics(eval_metrics)
# Add testing stats to summary
summary_eval = jax.tree_util.tree_map(lambda x: x.mean(), eval_metrics)
summary.update(summary_eval)
# Update iteration stats object
assert isinstance(self.itstat_object, IterationStats) # for mypy
self.itstat_object.insert(self.itstat_insert_func(ArgumentStruct(**summary)))
def checkpoint(self, state: TrainState):  # pragma: no cover
"""Checkpoint training state if enabled.
Args:
state: Flax train state.
"""
if self.checkpointing:
checkpoint_save(jax_utils.unreplicate(state), self.config, self.workdir)
def log(self, logstr: str):
"""Print stats to output terminal if logging is enabled.
Args:
logstr: String to be logged.
"""
if self.logflag:
print(logstr)