jax.tree_util.tree_flatten_with_path
- jax.tree_util.tree_flatten_with_path(tree, is_leaf=None)
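Flattens a pytree like tree_flatten, but also returns each leaf's key path.
- Parameters:
  - tree: the pytree to flatten.
  - is_leaf: an optional predicate called on nodes during traversal; if it returns True for a node, that node (and its whole subtree) is treated as a single leaf instead of being flattened further.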
- Return type:
tuple[list[tuple[tuple[TypeVar(KeyEntry, bound=Hashable), ...], Any]], PyTreeDef]
- Returns:
A pair in which the first element is a list of (key path, leaf) pairs, one per leaf, and the second element is a treedef representing the structure of the flattened tree.
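A minimal sketch of what the returned pair looks like in practice. It assumes a JAX release that exports keystr and tree_unflatten from jax.tree_util (both are present in recent versions); the pytree params and its contents are hypothetical, chosen only for illustration. The example flattens a nested dict, prints each leaf's key path, rebuilds the tree from the treedef, and then uses is_leaf to stop traversal at tuples.

```python
import jax.numpy as jnp
from jax.tree_util import keystr, tree_flatten_with_path, tree_unflatten

# Hypothetical nested pytree: a dict holding an array and a list that contains a tuple.
params = {"w": jnp.ones((2, 2)), "layers": [1.0, (2.0, 3.0)]}

# Each element of path_leaf_pairs is (key_path, leaf); treedef records the structure.
path_leaf_pairs, treedef = tree_flatten_with_path(params)

# Dict keys are visited in sorted order, so "layers" comes before "w".
for path, leaf in path_leaf_pairs:
    # keystr renders a key path as a readable string such as "['layers'][1][0]".
    print(keystr(path), leaf)

# Dropping the key paths leaves exactly the flat leaf list that tree_unflatten expects.
leaves = [leaf for _, leaf in path_leaf_pairs]
rebuilt = tree_unflatten(treedef, leaves)

# is_leaf stops traversal early: here every tuple is kept whole as a single leaf.
coarse_pairs, _ = tree_flatten_with_path(
    params, is_leaf=lambda node: isinstance(node, tuple)
)
for path, leaf in coarse_pairs:
    print(keystr(path), leaf)
```

Because the treedef is the same kind of PyTreeDef that tree_flatten returns, discarding the key paths and passing the leaves to tree_unflatten recovers the original structure. The key paths are only extra information for locating each leaf.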