jax.tree_util.tree_leaves_with_path
jax.tree_util.tree_leaves_with_path(tree, is_leaf=None)
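Gets the leaves of a pytree like tree_leaves and returns each leaf's key path. The result is a list of (key_path, leaf) pairs, in the same order as tree_leaves(tree). If is_leaf is given, it is called on each node during flattening; a node for which it returns True is treated as a leaf and is not descended into.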
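The short sketch below illustrates the pairs this function returns. It assumes a JAX version that provides jax.tree_util.keystr, used here only to pretty-print key paths; the params dictionary is an illustrative example, not part of the API. Note that dictionary entries are visited in sorted key order, so "b" comes before "w".

```python
from jax.tree_util import keystr, tree_leaves, tree_leaves_with_path

# A small, illustrative pytree: dicts, lists and tuples are built-in pytree nodes.
params = {"w": 1.0, "b": [2.0, (3.0, 4.0)]}

# Each element is a (key_path, leaf) pair; dict entries are visited in sorted key order.
pairs = tree_leaves_with_path(params)
for path, leaf in pairs:
    print(keystr(path), leaf)
# ['b'][0] 2.0
# ['b'][1][0] 3.0
# ['b'][1][1] 4.0
# ['w'] 1.0

# The leaves themselves match tree_leaves, in the same order.
assert [leaf for _, leaf in pairs] == tree_leaves(params)

# With is_leaf, any node for which the predicate returns True is kept whole.
pairs_tuple_leaf = tree_leaves_with_path(
    params, is_leaf=lambda node: isinstance(node, tuple)
)
for path, leaf in pairs_tuple_leaf:
    print(keystr(path), leaf)
# ['b'][0] 2.0
# ['b'][1] (3.0, 4.0)
# ['w'] 1.0
```

Because the key paths are returned alongside the leaves, this is a convenient way to label parameters (for example, when logging or filtering them) without writing a separate traversal.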