jax.tree.all
- jax.tree.all(tree)
Call all() over the leaves of a tree.
Alias of jax.tree_util.tree_all().
- Parameters:
  tree (Any) – the pytree to evaluate.
- Returns:
  True if every leaf of tree is truthy, otherwise False.
- Return type:
  bool
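Since the function is described as calling all() over the leaves of a tree, its behavior can be approximated by flattening the tree and passing the leaves to Python's built-in all(). The helper `tree_all_sketch` below is hypothetical. It is a minimal sketch that assumes this equivalence, not the library's actual implementation. Note that None is treated as an empty subtree rather than a leaf, and an empty tree yields True, just as all([]) does.

```python
import jax

def tree_all_sketch(tree):
    # Hypothetical helper: flatten the pytree to its leaves, then apply
    # Python's built-in all(). None is an empty subtree, not a leaf.
    return all(jax.tree.leaves(tree))

tree = [True, {'a': True, 'b': (True, 1)}]
print(tree_all_sketch(tree), jax.tree.all(tree))  # True True
print(tree_all_sketch([1, {'x': 0}]))             # False (0 is falsy)
print(jax.tree.all({}))                           # True (no leaves)
```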
Examples
>>> import jax
>>> jax.tree.all([True, {'a': True, 'b': (True, True)}])
True
>>> jax.tree.all([False, (True, False)])
False
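Because each leaf is tested for truthiness, a tree of multi-element arrays cannot be passed in directly. Calling bool() on such an array raises an error. A common pattern reduces every leaf to a single Python bool with jax.tree.map and then combines the results with jax.tree.all. The `params` dictionary below is a made-up example, and this is only a sketch of one way to check a condition across all leaves.

```python
import jax
import jax.numpy as jnp

# Made-up parameter tree for illustration.
params = {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}

# Reduce each array leaf to one Python bool, then combine across leaves.
all_finite = jax.tree.all(
    jax.tree.map(lambda x: bool(jnp.all(jnp.isfinite(x))), params))
all_positive = jax.tree.all(
    jax.tree.map(lambda x: bool(jnp.all(x > 0)), params))
print(all_finite, all_positive)  # True False ('b' is all zeros)
```

This pattern relies on concrete values. Inside a function transformed by jax.jit, calling bool() on a traced array fails, so the pattern applies to eager, untraced code.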