scico.flax.train.clu_utils

Utilities for displaying Flax models.

Functions

flatten_dict(input_dict[, prefix, delimiter])

Flatten keys of a nested dictionary.

get_parameter_overview(params, *[, ...])

Return a string listing variable names, their shapes, and their sizes.

get_parameter_rows(params, *[, include_stats])

Return information about parameters as a list of row objects.

make_table(rows, *[, column_names, ...])

Render list of rows to a table.

Classes

ParamRow(name, shape, size)

Definition of the structure of a row for printing parameters without stats.

ParamRowWithStats(name, shape, size, mean, std)

Definition of the structure of a row for printing parameters with stats.

class scico.flax.train.clu_utils.ParamRow(name, shape, size)

Bases: object

Definition of the structure of a row for printing parameters without stats.

class scico.flax.train.clu_utils.ParamRowWithStats(name, shape, size, mean, std)

Bases: ParamRow

Definition of the structure of a row for printing parameters with stats.

Inheritance diagram of ParamRowWithStats

scico.flax.train.clu_utils.flatten_dict(input_dict, prefix='', delimiter='/')

Flatten keys of a nested dictionary.

Parameters:
  • input_dict (Dict[str, Any]) – Nested dictionary.

  • prefix (str) – Prefix accumulated from keys that have already been flattened. Default: empty string.

  • delimiter (str) – Delimiter inserted between nested key names. Default: /.

Return type:

Dict[str, Any]

Returns:

A dictionary with the keys flattened.
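
The flattening described above can be illustrated with a minimal standalone sketch. It is not the library's source: the name flatten_dict_sketch is hypothetical, and the rule that nested keys are joined as prefix + delimiter + key is an assumption based on the parameter descriptions.

```python
from typing import Any, Dict


def flatten_dict_sketch(
    input_dict: Dict[str, Any], prefix: str = "", delimiter: str = "/"
) -> Dict[str, Any]:
    """Illustrative stand-in for flatten_dict (assumed behavior)."""
    output: Dict[str, Any] = {}
    for key, value in input_dict.items():
        # Join the accumulated prefix and the current key with the delimiter.
        nested_key = f"{prefix}{delimiter}{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-dictionaries, carrying the joined key as prefix.
            output.update(flatten_dict_sketch(value, nested_key, delimiter))
        else:
            output[nested_key] = value
    return output


nested = {"conv": {"kernel": [1, 2], "bias": [0]}, "scale": 3}
flat = flatten_dict_sketch(nested)
print(flat)  # {'conv/kernel': [1, 2], 'conv/bias': [0], 'scale': 3}
```

With SCICO installed, the corresponding call would presumably be flatten_dict(nested), using the signature documented above.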

scico.flax.train.clu_utils.get_parameter_rows(params, *, include_stats=False)

Return information about parameters as a list of row objects (ParamRow or ParamRowWithStats).

Parameters:
  • params (Union[Dict[str, ndarray], Mapping[str, Mapping[str, Any]]]) – Dictionary with parameters as NumPy arrays. The dictionary can be nested.

  • include_stats (bool) – If True, add columns with mean and std for each variable. Note that this can be considerably more compute-intensive and can cause a lot of memory to be transferred to the host.

Return type:

List[Union[ParamRow, ParamRowWithStats]]

Returns:

A list of ParamRow, or ParamRowWithStats, depending on the passed value of include_stats.
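
To make the two row types concrete, here is a rough sketch of how such rows could be built from an already flattened dictionary. Everything named here (RowSketch, RowWithStatsSketch, rows_sketch) is hypothetical. Plain 1-D Python lists stand in for NumPy arrays, and the use of the population standard deviation is an assumption.

```python
import dataclasses
import math
import statistics
from typing import Dict, List, Tuple, Union


@dataclasses.dataclass
class RowSketch:
    # Same fields as ParamRow in the signature above.
    name: str
    shape: Tuple[int, ...]
    size: int


@dataclasses.dataclass
class RowWithStatsSketch(RowSketch):
    # Adds the two statistics fields of ParamRowWithStats.
    mean: float
    std: float


def rows_sketch(
    flat_params: Dict[str, List[float]], *, include_stats: bool = False
) -> List[Union[RowSketch, RowWithStatsSketch]]:
    """Build one row per parameter; lists stand in for arrays (assumption)."""
    rows: List[Union[RowSketch, RowWithStatsSketch]] = []
    for name, values in flat_params.items():
        shape = (len(values),)
        size = math.prod(shape)  # number of elements in the parameter
        if include_stats:
            # Statistics require reading every value, hence the extra cost.
            mean = statistics.fmean(values)
            std = statistics.pstdev(values)  # population std (assumption)
            rows.append(RowWithStatsSketch(name, shape, size, mean, std))
        else:
            rows.append(RowSketch(name, shape, size))
    return rows


params = {"dense/kernel": [1.0, 2.0, 3.0, 4.0], "dense/bias": [0.5, 0.5]}
plain = rows_sketch(params)
with_stats = rows_sketch(params, include_stats=True)
print(plain[0])
print(with_stats[0])
```

The sketch also shows why include_stats is optional: the plain rows need only the shape of each parameter, while the statistics need the values themselves.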

scico.flax.train.clu_utils.make_table(rows, *, column_names=None, value_formatter=<function _default_table_value_formatter>, max_lines=None)

Render list of rows to a table.

Parameters:
  • rows (List[Any]) – List of dataclass instances of a single type (e.g. ParamRow).

  • column_names (Optional[Sequence[str]]) – List of columns that should be included in the output. If not provided, the columns are taken from the fields of the first row.

  • value_formatter (Callable[[Any], str]) – Callable used to format cell values.

  • max_lines (Optional[int]) – Don’t render a table longer than this.

Return type:

str

Returns:

A string representation of the table as in the example below.

+---------+---------+
| Col1    | Col2    |
+---------+---------+
| value11 | value12 |
| value21 | value22 |
+---------+---------+
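
The layout shown above can be reproduced with a short standalone sketch. It is only an approximation of the idea, not the library's code: make_table_sketch and the Row dataclass are hypothetical, all cells are left-aligned, at least one row is assumed, and the value_formatter and max_lines options are omitted.

```python
import dataclasses
from typing import Any, List, Optional, Sequence


@dataclasses.dataclass
class Row:
    # Hypothetical row type used only for this illustration.
    Col1: str
    Col2: str


def make_table_sketch(
    rows: List[Any], column_names: Optional[Sequence[str]] = None
) -> str:
    """Render dataclass rows as a bordered ASCII table (simplified sketch)."""
    if column_names is None:
        # Fall back to the dataclass fields of the first row.
        column_names = [f.name for f in dataclasses.fields(rows[0])]
    cells = [[str(getattr(row, name)) for name in column_names] for row in rows]
    # Each column is as wide as its longest entry, header included.
    widths = [
        max([len(name)] + [len(line[i]) for line in cells])
        for i, name in enumerate(column_names)
    ]
    border = "+" + "+".join("-" * (w + 2) for w in widths) + "+"

    def fmt(values: Sequence[str]) -> str:
        padded = (v.ljust(w) for v, w in zip(values, widths))
        return "| " + " | ".join(padded) + " |"

    lines = [border, fmt(column_names), border]
    lines += [fmt(line) for line in cells]
    lines.append(border)
    return "\n".join(lines)


table = make_table_sketch([Row("value11", "value12"), Row("value21", "value22")])
print(table)
```

Passing column_names=["Col2"] to the sketch renders only that column, which mirrors the role of the column_names parameter described above.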

scico.flax.train.clu_utils.get_parameter_overview(params, *, include_stats=True, max_lines=None)

Return a string listing variable names, their shapes, and their sizes.

Parameters:
  • params (Union[Dict[str, ndarray], Mapping[str, Mapping[str, Any]]]) – Dictionary with parameters as NumPy arrays. The dictionary can be nested.

  • include_stats (bool) – If True, add columns with mean and std for each variable.

  • max_lines (Optional[int]) – If not None, the maximum number of variables to include.

Return type:

str

Returns:

A string with a table as in the example below.

+----------------+---------------+------------+
| Name           | Shape         | Size       |
+----------------+---------------+------------+
| FC_1/weights:0 | (63612, 1024) | 65,138,688 |
| FC_1/biases:0  |       (1024,) |      1,024 |
| FC_2/weights:0 |    (1024, 32) |     32,768 |
| FC_2/biases:0  |         (32,) |         32 |
+----------------+---------------+------------+

Total: 65,172,512
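
The Size column and the Total line in the example follow directly from the shapes. The short check below recomputes them using only the standard library. The shapes are copied from the table above; the variable names are illustrative.

```python
import math

# Shapes copied from the example table above.
shapes = {
    "FC_1/weights:0": (63612, 1024),
    "FC_1/biases:0": (1024,),
    "FC_2/weights:0": (1024, 32),
    "FC_2/biases:0": (32,),
}

# The size of each variable is the product of its shape dimensions.
sizes = {name: math.prod(shape) for name, shape in shapes.items()}
total = sum(sizes.values())

for name, size in sizes.items():
    # "," in the format spec inserts thousands separators, as in the table.
    print(f"{name:<14} {size:>10,}")
print(f"Total: {total:,}")  # Total: 65,172,512
```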