Source code for scico.trace

# -*- coding: utf-8 -*-
# Copyright (C) 2024-2025 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
# package.

"""Call tracing of scico functions and methods.

JIT must be disabled for tracing to function correctly (set environment
variable :code:`JAX_DISABLE_JIT=1`, or call
:code:`jax.config.update('jax_disable_jit', True)` before importing `jax`
or `scico`). Call :code:`trace_scico_calls` to initialize tracing, and
call :code:`register_variable` to associate a name with a variable so
that it can be referenced by name in the call trace.

The call trace is color-coded as follows if
`colorama <https://github.com/tartley/colorama>`_ is installed:

- `module and class names`: light red
- `function and method names`: dark red
- `arguments and return values`: light blue
- `names of registered variables`: light yellow

When a method defined in a class is called for an object of a derived
class type, the class of that object is displayed in light magenta, in
square brackets. Function names and return values are distinguished by
initial ``>>`` and ``<<`` characters respectively.

A usage example is provided in the script :code:`trace_example.py`.
"""


from __future__ import annotations

import inspect
import sys
import types
import warnings
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Optional, Sequence

import numpy as np

import jax

try:
    from jaxlib.xla_extension import PjitFunction
except ImportError:
    from jaxlib._jax import PjitFunction  # jax >= 0.6.1


try:
    import colorama
except ImportError:
    have_colorama = False
else:
    have_colorama = True


if not have_colorama:
    # Without colorama every color code is an empty string, so the trace
    # output is plain, uncolored text.
    clr_main = clr_rvar = clr_self = clr_func = ""
    clr_args = clr_retv = clr_devc = clr_reset = ""
else:
    _fore_main = None  # placeholder removed below; keeps names explicit
    del _fore_main
    clr_main = colorama.Fore.LIGHTRED_EX  # main trace information
    clr_rvar = colorama.Fore.LIGHTYELLOW_EX  # registered variable names
    clr_self = colorama.Fore.LIGHTMAGENTA_EX  # type of object for which method is called
    clr_func = colorama.Fore.RED  # function/method name
    clr_args = colorama.Fore.LIGHTBLUE_EX  # function/method arguments
    clr_retv = colorama.Fore.LIGHTBLUE_EX  # function/method return values
    clr_devc = colorama.Fore.CYAN  # JAX array device and sharding
    clr_reset = colorama.Fore.RESET  # reset color


def _get_hash(val: Any) -> Optional[int]:
    """Get a hash representing an object.

    Args:
        val: An object for which the hash is required.

    Returns:
        A hash value of ``None`` if a hash cannot be computed.
    """
    if isinstance(val, np.ndarray):
        hash = val.ctypes.data  # for an ndarray, hash is the memory address
    elif hasattr(val, "__hash__") and callable(val.__hash__):
        try:
            hash = val.__hash__()
        except TypeError:
            hash = None
    else:
        hash = None
    return hash


def _trace_arg_repr(val: Any) -> str:
    """Compute string representation of function arguments.

    Args:
        val: Argument value

    Returns:
        A string representation of the argument.
    """
    if val is None:
        return "None"
    elif np.isscalar(val):  # a scalar value
        return str(val)
    elif isinstance(val, tuple) and len(val) < 6 and all([np.isscalar(s) for s in val]):
        return f"{val}"  # a short sequence of scalars
    elif isinstance(val, np.dtype):  # a numpy dtype
        return f"numpy.{val}"
    elif isinstance(val, type):  # a class name
        return f"{val.__module__}.{val.__qualname__}"
    elif isinstance(val, np.ndarray) and _get_hash(val) in call_trace.instance_hash:  # type: ignore
        return f"{clr_rvar}{call_trace.instance_hash[_get_hash(val)]}{clr_args}"  # type: ignore
    elif isinstance(val, (np.ndarray, jax.Array)):  # a jax or numpy array
        if val.shape == ():
            return str(val)
        else:
            dev_str, shard_str = "", ""
            if isinstance(val, jax.Array) and not isinstance(
                val, jax._src.interpreters.partial_eval.JaxprTracer
            ):
                if call_trace.show_jax_device:  # type: ignore
                    platform = list(val.devices())[0].platform  # assume all of same type
                    devices = ",".join(map(str, sorted([d.id for d in val.devices()])))
                    dev_str = f"{clr_devc}{{dev={platform}({devices})}}{clr_args}"
                if call_trace.show_jax_sharding and isinstance(  # type: ignore
                    val.sharding, jax._src.sharding_impls.PositionalSharding
                ):
                    shard_str = f"{clr_devc}{{shard={val.sharding.shape}}}{clr_args}"
            return f"Array{val.shape}{dev_str}{shard_str}"
    else:
        if _get_hash(val) in call_trace.instance_hash:  # type: ignore
            return f"{clr_rvar}{call_trace.instance_hash[val.__hash__()]}{clr_args}"  # type: ignore
        else:
            return f"[{type(val).__name__}]"


def register_variable(var: Any, name: str):
    """Register a variable name for call tracing.

    Associates `name` with `var` so that the call trace displays the name
    wherever `var` appears as an argument, return value, or method
    receiver. Any hashable object may be registered, as may a numpy array
    (whose data memory address serves as its hash). JAX arrays cannot be
    registered: they are not hashable and there is no clear mechanism for
    associating them with a unique memory address.

    Args:
        var: The variable to be registered.
        name: The name to be associated with the variable.

    Raises:
        ValueError: If no hash can be obtained for `var`.
    """
    key = _get_hash(var)
    if key is not None:
        call_trace.instance_hash[key] = name  # type: ignore
        return
    raise ValueError(f"Can't get hash for variable '{name}'.")
def _call_wrapped_function(func: Callable, *args, **kwargs) -> Any: """Call a wrapped function within the wrapper. Handle different call mechanisms required for static and class methods. Args: func: Wrapped function *args: Positional arguments **kwargs: Named arguments Returns: Return value of wrapped function. """ if isinstance(func, staticmethod): # If the type of the first argument is the same as the class to # which the static method belongs, assume that it was called as # <object>.<staticmethod>(<args>), which requires that the first # argument be stripped before calling the method. This is # somewhat heuristic, and may fail, but there is no obvious # mechanism for reliably determining how the method was called in # the calling scope. if inspect._findclass(func) == type(args[0]): # type: ignore call_args = args[1:] else: call_args = args ret = func(*call_args, **kwargs) elif isinstance(func, classmethod): ret = func.__func__(*args, **kwargs) else: ret = func(*args, **kwargs) return ret
def call_trace(func: Callable) -> Callable:
    """Print log of calls to `func`.

    Decorator for printing a log of calls to the wrapped function. A
    record of call levels is maintained so that call nesting is indicated
    by call log indentation.

    Args:
        func: Function or method (possibly a :class:`staticmethod` or
          :class:`classmethod` object) to be wrapped.

    Returns:
        The tracing wrapper, or `func` itself if it has already been
        wrapped by this decorator.
    """
    # Avoid multiple wrapper layers
    if hasattr(func, "_call_trace_wrap"):
        return func

    try:
        method_class = inspect._findclass(func)  # type: ignore
    except AttributeError:
        method_class = None

    @wraps(func)
    def wrapper(*args, **kwargs):
        name = f"{func.__module__}.{clr_func}{func.__qualname__}"
        arg_idx = 0
        if (
            args
            and hasattr(args[0], "__hash__")
            and callable(args[0].__hash__)
            and method_class
            and isinstance(args[0], method_class)
        ):
            # first argument is self for a method call
            arg_idx = 1  # skip self in handling arguments
            self_hash = args[0].__hash__()
            if self_hash in call_trace.instance_hash:
                # self object registered using register_variable
                name = (
                    f"{clr_rvar}{call_trace.instance_hash[self_hash]}."
                    f"{clr_func}{func.__name__}"
                )
            else:
                # self object not registered
                func_class = method_class.__name__
                self_class = args[0].__class__.__name__
                # If the class in which this method is defined is the same as
                # that of the self object for which it's called, just display
                # the class name. Otherwise, display the name of the defining
                # class followed by the name of the self object class in
                # square brackets.
                if func_class == self_class:
                    class_name = func_class
                else:
                    class_name = f"{func_class}{clr_self}[{self_class}]{clr_main}"
                name = f"{func.__module__}.{class_name}.{clr_func}{func.__name__}"
        args_repr = [_trace_arg_repr(val) for val in args[arg_idx:]]
        kwargs_repr = [f"{key}={_trace_arg_repr(val)}" for key, val in kwargs.items()]
        args_str = clr_args + ", ".join(args_repr + kwargs_repr) + clr_main
        print(
            f"{clr_main}>> {' ' * 2 * call_trace.trace_level}{name}"
            f"({args_str}{clr_func}){clr_reset}",
            file=sys.stderr,
        )
        # call wrapped function
        call_trace.trace_level += 1
        try:
            ret = _call_wrapped_function(func, *args, **kwargs)
        finally:
            # Restore the call level even if the wrapped function raises;
            # otherwise all subsequent trace output would be mis-indented.
            call_trace.trace_level -= 1
        # print representation of return value
        if ret is not None and call_trace.show_return_value:
            print(
                f"{clr_main}<< {' ' * 2 * call_trace.trace_level}{clr_retv}"
                f"{_trace_arg_repr(ret)}{clr_reset}",
                file=sys.stderr,
            )
        return ret

    # Set flag indicating that function is already wrapped
    wrapper._call_trace_wrap = True  # type: ignore
    return wrapper
# call level counter for call_trace decorator
call_trace.trace_level = 0  # type: ignore
# hash dict allowing association of objects with variable names
call_trace.instance_hash = {}  # type: ignore
# flag indicating whether to show function return value
call_trace.show_return_value = True  # type: ignore
# flag indicating whether to show JAX array devices
call_trace.show_jax_device = False  # type: ignore
# flag indicating whether to show JAX array sharding shape
call_trace.show_jax_sharding = False  # type: ignore


def _submodule_name(module: types.ModuleType, obj: Any) -> str:
    """Get the name of a submodule relative to a parent module.

    Args:
        module: Parent module.
        obj: Object (typically a module) with a ``__name__`` attribute.

    Returns:
        The part of ``obj.__name__`` following ``module.__name__ + "."``,
        or an empty string if `obj` is not named as a submodule of
        `module`.
    """
    # Require the "." separator so that e.g. "scico.linop_x" is not
    # treated as being inside "scico.linop".
    prefix = module.__name__ + "."
    if obj.__name__.startswith(prefix):
        return obj.__name__[len(prefix) :]
    return ""


def _is_scico_object(obj: Any) -> bool:
    """Determine whether an object is defined in a scico submodule.

    Args:
        obj: Object to check.

    Returns:
        A boolean value indicating whether `obj` is defined in a scico
        submodule.
    """
    # __module__ may be missing or None for some objects
    mod_name = getattr(obj, "__module__", None)
    return isinstance(mod_name, str) and (
        mod_name == "scico" or mod_name.startswith("scico.")
    )


def _is_scico_module(mod: types.ModuleType) -> bool:
    """Determine whether a module is a scico submodule.

    Args:
        mod: Module to check.

    Returns:
        A boolean value indicating whether `mod` is a scico submodule.
    """
    mod_name = getattr(mod, "__name__", None)
    return isinstance(mod_name, str) and (
        mod_name == "scico" or mod_name.startswith("scico.")
    )


def _in_module(mod: types.ModuleType, obj: Any) -> bool:
    """Determine whether an object is defined in a module.

    Args:
        mod: Module of interest.
        obj: Object to check.

    Returns:
        A boolean value indicating whether `obj` is defined in `mod`.
    """
    return getattr(obj, "__module__", None) == mod.__name__


def _is_submodule(mod: types.ModuleType, submod: types.ModuleType) -> bool:
    """Determine whether a module is a submodule of another module.

    Args:
        mod: Parent module of interest.
        submod: Possible submodule to check.

    Returns:
        A boolean value indicating whether `submod` is a (possibly nested)
        submodule of `mod`. A module is not considered a submodule of
        itself, and a module whose name merely shares a prefix (e.g.
        "scico.functional" relative to "scico.function") is not a
        submodule.
    """
    return submod.__name__.startswith(mod.__name__ + ".")
def apply_decorator(
    module: types.ModuleType,
    decorator: Callable,
    recursive: bool = True,
    skip: Optional[Sequence] = None,
    seen: Optional[defaultdict[str, int]] = None,
    verbose: bool = False,
    level: int = 0,
) -> defaultdict[str, int]:
    """Apply a decorator function to all functions in a scico module.

    Apply a decorator function to all functions in a scico module,
    including methods of classes in that module.

    Args:
        module: The module containing the functions/methods to be
            decorated.
        decorator: The decorator function to apply to each module
            function/method.
        recursive: Flag indicating whether to recurse into submodules
            of the specified module. (Hidden modules with a name
            starting with an underscore are ignored.)
        skip: A list of class/function/method names to be skipped.
        seen: A :class:`defaultdict` providing a count of the number of
            times each function/method was seen.
        verbose: Flag indicating whether to print a log of functions
            as they are encountered.
        level: Counter for recursive call levels.

    Returns:
        A :class:`defaultdict` providing a count of the number of times
        each function/method was seen.
    """
    indent = " " * 4 * level
    if skip is None:
        skip = []
    if seen is None:
        seen = defaultdict(int)
    if verbose:
        print(f"{indent}Module: {module.__name__}")
    indent += " " * 4

    def _is_function(obj: Any) -> bool:
        # plain Python functions and jax-jitted functions are decorated
        return isinstance(obj, (types.FunctionType, PjitFunction))

    # Iterate over functions in module
    for func_name, func in inspect.getmembers(
        module, lambda obj: _is_function(obj) and _in_module(module, obj)
    ):
        if func_name in skip:
            continue
        qualname = func.__module__ + "." + func.__qualname__
        if not seen[qualname]:  # avoid multiple applications of decorator
            setattr(module, func_name, decorator(func))
        seen[qualname] += 1
        if verbose:
            print(f"{indent}Function: {qualname}")

    # Iterate over classes in module
    for cls_name, cls in inspect.getmembers(
        module, lambda obj: inspect.isclass(obj) and _in_module(module, obj)
    ):
        # class names in skip are skipped, as documented above
        if cls_name in skip:
            continue
        qualname = cls.__module__ + "." + cls.__qualname__  # type: ignore
        if verbose:
            print(f"{indent}Class: {qualname}")
        # Iterate over methods in class
        for meth_name, meth in inspect.getmembers(cls, _is_function):
            if meth_name in skip:
                continue
            qualname = meth.__module__ + "." + meth.__qualname__  # type: ignore
            if not seen[qualname]:  # avoid multiple applications of decorator
                # Can't use the method returned by inspect.getmembers because it
                # uses plain getattr internally, which interferes with
                # identification of static methods. From Python 3.11 onwards one
                # could use inspect.getmembers_static instead of
                # inspect.getmembers, but that would imply incompatibility with
                # earlier Python versions.
                raw_meth = inspect.getattr_static(cls, meth_name)
                setattr(cls, meth_name, decorator(raw_meth))
            seen[qualname] += 1
            if verbose:
                print(f"{indent + ' '}Method: {qualname}")

    # Iterate over submodules of module
    if recursive:
        for mod_name, mod in inspect.getmembers(
            module, lambda obj: inspect.ismodule(obj) and _is_submodule(module, obj)
        ):
            if mod_name.startswith("_"):  # hidden module
                continue
            seen = apply_decorator(
                mod,
                decorator,
                recursive=recursive,
                skip=skip,
                seen=seen,
                verbose=verbose,
                level=level + 1,
            )
    return seen
def trace_scico_calls(verbose: bool = False):
    """Enable tracing of calls to all significant scico functions/methods.

    Enable tracing of calls to all significant scico functions and
    methods. Note that JIT should be disabled to ensure correct
    functioning of the tracing mechanism.

    Args:
        verbose: Flag indicating whether to print a log of the modules,
          classes, functions, and methods as the tracing decorator is
          applied to them.
    """
    if not jax.config.jax_disable_jit:
        # stacklevel=2 attributes the warning to the caller of this
        # function rather than to this module.
        warnings.warn(
            "Call tracing requested but jit is not disabled. Disable jit"
            " by setting the environment variable JAX_DISABLE_JIT=1, or use"
            " jax.config.update('jax_disable_jit', True).",
            stacklevel=2,
        )
    from scico import (
        function,
        functional,
        linop,
        loss,
        metric,
        operator,
        optimize,
        solver,
    )

    seen = None
    for module in (functional, linop, loss, operator, optimize, function, metric, solver):
        seen = apply_decorator(module, call_trace, skip=["__repr__"], seen=seen, verbose=verbose)