# Source code for scico.flax.train.clu_utils

# -*- coding: utf-8 -*-
# Copyright (C) 2022-2023 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Utilities for displaying Flax models."""

# These utilities have been copied from the Common Loop Utils (CLU)
# and have been modified to remove TensorFlow dependencies

# CLU is licensed under the Apache License, Version 2.0 (the "License");
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0

import warnings

# NOTE(review): `warnings.simplefilter` installs a process-wide filter, so
# this silences every FutureWarning once this module is imported, not only
# those raised here. Presumably it is meant to hide FutureWarnings emitted by
# the jax/flax imports below; confirm, and consider `warnings.catch_warnings`
# around those imports instead.
warnings.simplefilter(action="ignore", category=FutureWarning)

import dataclasses
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

import jax

import flax

# Arbitrary pytree (nested containers of arrays). It is left as ``Any``
# because pytrees may be nested to any depth.
PyTree = Any

# Parameter collections accepted by the overview helpers: either a flat
# mapping of names to NumPy arrays, or a nested mapping of mappings.
ParamsContainer = Union[
    Dict[str, np.ndarray],
    Mapping[str, Mapping[str, Any]],
]

@dataclasses.dataclass
class ParamRow:
    """Definition of the structure of a row for printing parameters without stats."""

    # Flattened parameter name, e.g. ``"Dense_0/kernel"``.
    name: str
    # Array shape. It is a variable-length tuple, hence ``Tuple[int, ...]``
    # rather than ``Tuple[int]``, which would denote a 1-tuple.
    shape: Tuple[int, ...]
    # Total number of entries (product of ``shape``).
    size: int


@dataclasses.dataclass
class ParamRowWithStats(ParamRow):
    """Definition of the structure of a row for printing parameters with stats."""

    # Mean of all entries of the parameter array.
    mean: float
    # Standard deviation of all entries of the parameter array.
    std: float
def flatten_dict(
    input_dict: Dict[str, Any], prefix: str = "", delimiter: str = "/"
) -> Dict[str, Any]:
    """Collapse a nested dictionary into a single-level dictionary.

    Keys of nested dictionaries are joined to their parents' keys with
    `delimiter`, so ``{"a": {"b": 1}}`` becomes ``{"a/b": 1}``.

    Args:
        input_dict: Nested dictionary. Any level may be a plain ``dict`` or
            a Flax ``FrozenDict``.
        prefix: Key path accumulated from the enclosing levels. Keys at
            this level are appended to it. Default: empty string.
        delimiter: Separator placed between key components.
            Default: ``/``.

    Returns:
        A new dictionary mapping joined key paths to the non-dictionary
        leaf values.
    """
    flat: Dict[str, Any] = {}
    for key, value in input_dict.items():
        # At the top level (empty prefix) the key is used unchanged.
        path = key if not prefix else f"{prefix}{delimiter}{key}"
        if not isinstance(value, (dict, flax.core.FrozenDict)):
            flat[path] = value
            continue
        flat.update(flatten_dict(value, prefix=path, delimiter=delimiter))
    return flat
def count_parameters(params: PyTree) -> int:
    """Return count of variables for the parameter dictionary.

    Args:
        params: Flax model parameters, as a (possibly nested) ``dict`` or
            ``FrozenDict`` of arrays.

    Returns:
        The number of parameters in the model (the sum of the sizes of all
        parameter arrays), as a Python ``int``.
    """
    flat_params = flatten_dict(params)
    # Convert each term to ``int``: ``np.prod(())`` is the float ``1.0``, so a
    # scalar parameter would otherwise turn the whole total into a float.
    # ``None`` leaves carry no parameters and are skipped.
    return sum(int(np.prod(v.shape)) for v in flat_params.values() if v is not None)
def get_parameter_rows(
    params: ParamsContainer,
    *,
    include_stats: bool = False,
) -> List[Union[ParamRow, ParamRowWithStats]]:
    """Return information about parameters as a list of rows.

    Nested dictionaries are flattened with ``/``-joined names, and the
    rows are sorted by name.

    Args:
        params: Dictionary with parameters as NumPy arrays. The dictionary
            can be nested.
        include_stats: If ``True`` add columns with mean and std for each
            variable. Note that this can be considerably more compute
            intensive and cause a lot of memory to be transferred to the
            host.

    Returns:
        A list of `ParamRow`, or `ParamRowWithStats`, depending on the
        passed value of `include_stats`.

    Raises:
        AssertionError: If `params` is not a ``dict`` or ``FrozenDict``.
    """
    assert isinstance(params, (dict, flax.core.FrozenDict))

    def make_row(name, value):
        shape = value.shape
        size = int(np.prod(shape))
        if include_stats:
            return ParamRowWithStats(
                name=name,
                shape=shape,
                size=size,
                mean=float(value.mean()),
                std=float(value.std()),
            )
        return ParamRow(name=name, shape=shape, size=size)

    flat_params = flatten_dict(params)
    # Sort on the name only. Names are unique, and this never compares arrays.
    return [
        make_row(name, value)
        for name, value in sorted(flat_params.items(), key=lambda item: item[0])
    ]
def _default_table_value_formatter(value):
    """Format a table cell value as a string.

    Booleans are rendered with ``str``. They are checked first because
    ``bool`` is a subclass of ``int``. Integers get ``,`` thousands
    separators, floats are shown with 3 significant digits (e.g.
    ``1234.5678`` -> ``1.23e+03``), and anything else is rendered with
    ``str``.
    """
    if isinstance(value, bool):
        return str(value)
    if isinstance(value, int):
        return f"{value:,}"
    if isinstance(value, float):
        return f"{value:.3}"
    return str(value)


def make_table(
    rows: List[Any],
    *,
    column_names: Optional[Sequence[str]] = None,
    value_formatter: Callable[[Any], str] = _default_table_value_formatter,
    max_lines: Optional[int] = None,
) -> str:
    """Render list of rows to a table.

    Args:
        rows: List of dataclass instances of a single type (e.g. `ParamRow`).
        column_names: List of columns that should be included in the
            output. If not provided, then the columns are taken from the
            fields of the first row.
        value_formatter: Callable used to format cell values.
        max_lines: If given, the table is truncated once it reaches roughly
            this many lines (header and separator lines included). A
            ``[...]`` marker line is added in place of the omitted rows.

    Returns:
        A string representation of the table as in the example below, or
        ``"(empty table)"`` when `rows` is empty and `column_names` is not
        given.

        ::

            +---------+---------+
            | Col1    | Col2    |
            +---------+---------+
            | value11 | value12 |
            | value21 | value22 |
            +---------+---------+

    Raises:
        ValueError: If any element of `rows` is not a dataclass, or if the
            elements are not all of the same type.
    """
    if any(not dataclasses.is_dataclass(row) for row in rows):
        raise ValueError("Expected `rows` to be list of dataclasses")
    if len(set(map(type, rows))) > 1:
        raise ValueError("Expected elements of `rows` be of same type.")

    class Column:
        """Definition of a column for printing parameters."""

        def __init__(self, name, values):
            self.name = name.capitalize()
            self.values = values
            # Wide enough for the header and for every formatted cell.
            self.width = max(len(v) for v in values + [name])

    if column_names is None:
        if not rows:
            return "(empty table)"
        column_names = [field.name for field in dataclasses.fields(rows[0])]

    columns = [
        Column(name, [value_formatter(getattr(row, name)) for row in rows])
        for name in column_names
    ]

    # Left-aligned cells padded to each column's width, e.g. "| {: <4s} |".
    var_line_format = "|" + "".join(f" {{: <{c.width}s}} |" for c in columns)
    # The same layout with '-' as padding and '+' at the corners gives the
    # separator lines.
    sep_line_format = var_line_format.replace(" ", "-").replace("|", "+")
    header = var_line_format.format(*[c.name for c in columns])
    separator = sep_line_format.format(*([""] * len(columns)))

    lines = [separator, header, separator]
    for i in range(len(rows)):
        # Reserve room for the "[...]" marker and the closing separator.
        if max_lines and len(lines) >= max_lines - 3:
            lines.append("[...]")
            break
        lines.append(var_line_format.format(*[c.values[i] for c in columns]))
    lines.append(separator)
    return "\n".join(lines)
def get_parameter_overview(
    params: ParamsContainer, *, include_stats: bool = True, max_lines: Optional[int] = None
) -> str:
    """Return string with variables names, their shapes, count.

    Args:
        params: Dictionary with parameters as NumPy arrays. The dictionary
            can be nested.
        include_stats: If ``True``, add columns with mean and std for each
            variable.
        max_lines: If not ``None``, the approximate maximum number of lines
            of the rendered table (header and separator lines included).
            Longer tables are truncated with a ``[...]`` marker line. The
            ``Total`` line is appended after the table regardless.

    Returns:
        A string with a table as in the example below (shown for
        ``include_stats=False``; with stats, ``Mean`` and ``Std`` columns
        are added).

        ::

            +----------------+---------------+------------+
            | Name           | Shape         | Size       |
            +----------------+---------------+------------+
            | FC_1/weights:0 | (63612, 1024) | 65,138,688 |
            | FC_1/biases:0  |       (1024,) |      1,024 |
            | FC_2/weights:0 |    (1024, 32) |     32,768 |
            | FC_2/biases:0  |         (32,) |         32 |
            +----------------+---------------+------------+
            Total: 65,172,512
    """
    if isinstance(params, (dict, flax.core.FrozenDict)):
        # Convert every leaf to a NumPy array. For JAX arrays this copies the
        # data to the host before shapes and stats are computed.
        params = jax.tree_util.tree_map(np.asarray, params)
    rows = get_parameter_rows(params, include_stats=include_stats)
    # Coerce to int so the "," format below never renders a float total.
    total_weights = int(count_parameters(params))
    RowType = ParamRowWithStats if include_stats else ParamRow
    # Pass in `column_names` to enable rendering empty tables.
    column_names = [field.name for field in dataclasses.fields(RowType)]
    table = make_table(rows, max_lines=max_lines, column_names=column_names)
    return table + f"\nTotal: {total_weights:,}"