scico.flax.train.clu_utils
Utilities for displaying Flax models.
Functions
| flatten_dict | Flatten keys of a nested dictionary. |
| get_parameter_overview | Return a string with variable names, their shapes, and parameter counts. |
| get_parameter_rows | Return information about parameters as a list of dictionaries. |
| make_table | Render a list of rows to a table. |
Classes
| ParamRow | Definition of the structure of a row for printing parameters without stats. |
| ParamRowWithStats | Definition of the structure of a row for printing parameters with stats. |
- class scico.flax.train.clu_utils.ParamRow(name, shape, size)
Bases: object
Definition of the structure of a row for printing parameters without stats.
- class scico.flax.train.clu_utils.ParamRowWithStats(name, shape, size, mean, std)
Bases: ParamRow
Definition of the structure of a row for printing parameters with stats.
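Since make_table expects rows that are dataclass instances of a single type, the two row classes can be pictured as the dataclasses sketched below. This is an illustration, not the scico source: the field types (string name, integer tuple for the shape, integer size, float statistics) are assumptions, and the Sketch suffix marks the class names as hypothetical.

```python
import dataclasses
from typing import Tuple


@dataclasses.dataclass
class ParamRowSketch:
    """Hypothetical stand-in for ParamRow: one parameter, no statistics."""

    name: str  # flattened variable name, e.g. "FC_1/weights:0"
    shape: Tuple[int, ...]  # array shape
    size: int  # number of elements


@dataclasses.dataclass
class ParamRowWithStatsSketch(ParamRowSketch):
    """Hypothetical stand-in for ParamRowWithStats: adds mean and std."""

    mean: float
    std: float


row = ParamRowWithStatsSketch(name="FC_2/biases:0", shape=(32,), size=32, mean=0.0, std=1.0)
print(dataclasses.asdict(row))
# {'name': 'FC_2/biases:0', 'shape': (32,), 'size': 32, 'mean': 0.0, 'std': 1.0}
```

Subclassing means a row with statistics is also a ParamRow, consistent with the "Bases: ParamRow" entry above, and the stats fields simply extend the field list.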
- scico.flax.train.clu_utils.flatten_dict(input_dict, prefix='', delimiter='/')
Flatten keys of a nested dictionary.
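Flattening can be done recursively, joining the keys along each path with delimiter. The sketch below mirrors the documented signature but is not the scico source; the name flatten_dict_sketch is hypothetical, and the handling of prefix (prepended to every key, separated by delimiter, when non-empty) is an assumption.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_dict_sketch(
    input_dict: Mapping, prefix: str = "", delimiter: str = "/"
) -> Dict[str, Any]:
    """Hypothetical re-implementation: join nested keys with `delimiter`."""
    output: Dict[str, Any] = {}
    for key, value in input_dict.items():
        # Prepend the path accumulated so far, if any.
        full_key = f"{prefix}{delimiter}{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            # Recurse into sub-dictionaries, carrying the path along.
            output.update(flatten_dict_sketch(value, full_key, delimiter))
        else:
            output[full_key] = value
    return output


nested = {"FC_1": {"weights:0": 1, "biases:0": 2}, "scale": 3}
print(flatten_dict_sketch(nested))
# {'FC_1/weights:0': 1, 'FC_1/biases:0': 2, 'scale': 3}
```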
- scico.flax.train.clu_utils.get_parameter_rows(params, *, include_stats=False)
Return information about parameters as a list of dictionaries.
- Parameters:
params (Union[Dict[str, ndarray], Mapping[str, Mapping[str, Any]]]) – Dictionary with parameters as NumPy arrays. The dictionary can be nested.
include_stats (bool) – If True, add columns with mean and std for each variable. Note that this can be considerably more compute-intensive and can cause a lot of memory to be transferred to the host.
- Return type:
List[ParamRow]
- Returns:
A list of ParamRow, or ParamRowWithStats, depending on the passed value of include_stats.
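Conceptually, the rows come from flattening the parameter dictionary and recording each array's name, shape, and size, plus its mean and standard deviation when include_stats is True. A minimal NumPy sketch, assuming plain nested dicts of arrays; Row, RowWithStats, _flatten, and parameter_rows_sketch are hypothetical names, not part of scico.

```python
import dataclasses
from collections.abc import Mapping
from typing import Any, Dict, List, Tuple

import numpy as np


@dataclasses.dataclass
class Row:  # hypothetical stand-in for ParamRow
    name: str
    shape: Tuple[int, ...]
    size: int


@dataclasses.dataclass
class RowWithStats(Row):  # hypothetical stand-in for ParamRowWithStats
    mean: float
    std: float


def _flatten(tree: Mapping, prefix: str = "") -> Dict[str, Any]:
    """Join nested keys with "/" into a single flat dictionary."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(_flatten(value, name))
        else:
            flat[name] = value
    return flat


def parameter_rows_sketch(params: Mapping, *, include_stats: bool = False) -> List[Row]:
    rows: List[Row] = []
    for name, value in _flatten(params).items():
        # np.asarray copies device (e.g. JAX) arrays into host memory.
        arr = np.asarray(value)
        if include_stats:
            # Full reductions over every element: the costly part of include_stats.
            rows.append(RowWithStats(name, arr.shape, arr.size, float(arr.mean()), float(arr.std())))
        else:
            rows.append(Row(name, arr.shape, arr.size))
    return rows


params = {"FC_2": {"weights:0": np.ones((1024, 32)), "biases:0": np.zeros(32)}}
for row in parameter_rows_sketch(params, include_stats=True):
    print(row)
```

The host transfer and the per-array reductions in the include_stats branch are why the parameter description above warns that statistics can be expensive for large models.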
- scico.flax.train.clu_utils.make_table(rows, *, column_names=None, value_formatter=<function _default_table_value_formatter>, max_lines=None)
Render a list of rows to a table.
- Parameters:
rows (List[Any]) – List of dataclass instances of a single type (e.g. ParamRow).
column_names (Optional[Sequence[str]]) – List of columns that should be included in the output. If not provided, the columns are taken from the keys of the first row.
value_formatter (Callable[[Any], str]) – Callable used to format cell values.
max_lines (Optional[int]) – Don’t render a table longer than this.
- Return type:
str
- Returns:
A string representation of the table as in the example below.
+---------+---------+
| Col1    | Col2    |
+---------+---------+
| value11 | value12 |
| value21 | value22 |
+---------+---------+
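A table in this format can be produced by sizing each column to its widest cell (header included) and left-justifying every cell to that width. The sketch below is an illustration under assumptions, not the scico implementation: make_table_sketch and Cols are hypothetical names, str stands in for the private _default_table_value_formatter, and the max_lines behavior (keep the first lines, then a "..." line) is a guess.

```python
import dataclasses
from typing import Any, Callable, List, Optional, Sequence


def make_table_sketch(
    rows: List[Any],
    *,
    column_names: Optional[Sequence[str]] = None,
    value_formatter: Callable[[Any], str] = str,
    max_lines: Optional[int] = None,
) -> str:
    """Hypothetical renderer for a list of dataclass rows of one type."""
    if not rows:
        return ""
    dicts = [dataclasses.asdict(r) for r in rows]
    # Default columns: the field names of the first row.
    names = list(column_names) if column_names is not None else list(dicts[0])
    cells = [[value_formatter(d[n]) for n in names] for d in dicts]
    # Each column is as wide as its widest entry, header included.
    widths = [max(len(n), *(len(row[i]) for row in cells)) for i, n in enumerate(names)]
    sep = "+" + "+".join("-" * (w + 2) for w in widths) + "+"

    def fmt(values: Sequence[str]) -> str:
        return "| " + " | ".join(v.ljust(w) for v, w in zip(values, widths)) + " |"

    lines = [sep, fmt(names), sep] + [fmt(r) for r in cells] + [sep]
    if max_lines is not None and len(lines) > max_lines:
        # Assumed truncation policy: keep the first lines and mark the cut.
        lines = lines[: max_lines - 1] + ["..."]
    return "\n".join(lines)


@dataclasses.dataclass
class Cols:  # hypothetical two-column row type
    Col1: str
    Col2: str


print(make_table_sketch([Cols("value11", "value12"), Cols("value21", "value22")]))
```

Run as written, this prints the six-line Col1/Col2 table shown above.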
- scico.flax.train.clu_utils.get_parameter_overview(params, *, include_stats=True, max_lines=None)
Return a string with variable names, their shapes, and parameter counts.
- Parameters:
params (Union[Dict[str, ndarray], Mapping[str, Mapping[str, Any]]]) – Dictionary with parameters as NumPy arrays. The dictionary can be nested.
include_stats (bool) – If True, add columns with mean and std for each variable.
max_lines (Optional[int]) – If not None, the maximum number of variables to include.
- Return type:
str
- Returns:
A string with a table as in the example below.
+----------------+---------------+------------+
| Name           | Shape         | Size       |
+----------------+---------------+------------+
| FC_1/weights:0 | (63612, 1024) | 65,138,688 |
| FC_1/biases:0  | (1024,)       | 1,024      |
| FC_2/weights:0 | (1024, 32)    | 32,768     |
| FC_2/biases:0  | (32,)         | 32         |
+----------------+---------------+------------+
Total: 65,172,512
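An overview of this kind combines the earlier steps: flatten the parameters, build one row per variable, render the table, and append the total element count (65,138,688 + 1,024 + 32,768 + 32 = 65,172,512 in the example). The self-contained sketch below assumes plain nested dicts of arrays. parameter_overview_sketch, _flatten, and _fmt are hypothetical helpers. The int/float formatting in _fmt is an assumption, as is computing the total over all variables even when max_lines shortens the listing.

```python
from collections.abc import Mapping
from typing import Any, Dict, List, Optional, Sequence

import numpy as np


def _flatten(tree: Mapping, prefix: str = "") -> Dict[str, Any]:
    """Join nested keys with "/" into a single flat dictionary."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, Mapping):
            flat.update(_flatten(value, name))
        else:
            flat[name] = value
    return flat


def _fmt(value: Any) -> str:
    """Thousands separators for ints (as in the Size column), 3 digits for floats."""
    if isinstance(value, int):
        return f"{value:,}"
    if isinstance(value, float):
        return f"{value:.3g}"
    return str(value)


def parameter_overview_sketch(
    params: Mapping, *, include_stats: bool = True, max_lines: Optional[int] = None
) -> str:
    flat = _flatten(params)
    names = list(flat)
    if max_lines is not None:
        names = names[:max_lines]  # assumed: cap on the number of variables listed
    header = ["Name", "Shape", "Size"] + (["Mean", "Std"] if include_stats else [])
    body: List[List[str]] = []
    for name in names:
        arr = np.asarray(flat[name])
        cells = [name, str(arr.shape), _fmt(int(arr.size))]
        if include_stats:
            cells += [_fmt(float(arr.mean())), _fmt(float(arr.std()))]
        body.append(cells)
    # Each column is as wide as its widest cell, header included.
    widths = [max(len(r[i]) for r in [header] + body) for i in range(len(header))]
    sep = "+" + "+".join("-" * (w + 2) for w in widths) + "+"

    def line(cells: Sequence[str]) -> str:
        return "| " + " | ".join(c.ljust(w) for c, w in zip(cells, widths)) + " |"

    total = sum(int(np.asarray(v).size) for v in flat.values())  # all variables
    table = [sep, line(header), sep] + [line(c) for c in body] + [sep]
    return "\n".join(table + [f"Total: {total:,}"])


# broadcast_to yields an array of the right shape without allocating 65M values.
params = {
    "FC_1": {
        "weights:0": np.broadcast_to(np.float32(0), (63612, 1024)),
        "biases:0": np.zeros(1024, np.float32),
    },
    "FC_2": {"weights:0": np.zeros((1024, 32), np.float32), "biases:0": np.zeros(32, np.float32)},
}
print(parameter_overview_sketch(params, include_stats=False))
```

With include_stats=False, this reproduces the example table above; with the default include_stats=True, Mean and Std columns are appended.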