jax.extend.linear_util.cache
- jax.extend.linear_util.cache(call)
Memoization decorator for functions that take a WrappedFun as their first argument.
- Parameters:
call (Callable) – a Python callable that takes a WrappedFun as its first argument. The underlying transforms and params on the WrappedFun are used as part of the memoization cache key.
- Returns:
A memoized version of call.
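The following is a toy model in plain Python, not JAX's implementation. It illustrates the keying behavior described above: the transforms and params carried by the wrapped function, together with the underlying function and the remaining call arguments, decide whether a cached result is reused. ToyWrappedFun and toy_cache are hypothetical names standing in for WrappedFun and cache. Because the documentation says transforms and params are only part of the key, the real key may include more. Keeping one cache per underlying function in a weakref.WeakKeyDictionary is also an assumption of this sketch.

```python
import weakref
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: a function plus the
    transforms and params applied to it (both kept as hashable tuples)."""
    f: Callable[..., Any]
    transforms: tuple = ()
    params: tuple = ()  # sorted (name, value) pairs

    def call_wrapped(self, *args: Any) -> Any:
        # Call the underlying function with the stored params as kwargs.
        return self.f(*args, **dict(self.params))


def toy_cache(call: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical analogue of cache: memoize call(fun, *args) on
    (fun.transforms, fun.params, args), with one cache per fun.f."""
    # Weak keys: a function's entries are dropped once it is garbage collected.
    per_fun: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

    def memoized_fun(fun: ToyWrappedFun, *args: Any) -> Any:
        entries = per_fun.setdefault(fun.f, {})
        key = (fun.transforms, fun.params, args)  # args must be hashable
        if key not in entries:
            entries[key] = call(fun, *args)
        return entries[key]

    return memoized_fun


# Usage: count how often the wrapped computation really runs.
calls = []


@toy_cache
def lower(fun: ToyWrappedFun, backend: str) -> str:
    calls.append(backend)  # only reached on a cache miss
    return f"lowered {fun.f.__name__} for {backend}"


def square(x, scale=1):
    return scale * x * x


w1 = ToyWrappedFun(square, params=(("scale", 2),))
w2 = ToyWrappedFun(square, params=(("scale", 2),))  # distinct object, same key
w3 = ToyWrappedFun(square, params=(("scale", 3),))  # different params

lower(w1, "cpu")  # miss: runs lower
lower(w2, "cpu")  # hit: same f, transforms, params and args
lower(w3, "cpu")  # miss: params differ
print(calls)  # ['cpu', 'cpu']
```

Two wrappers around the same function with equal transforms and params share a cache entry even though they are different objects, so w2 is a hit. Changing params (w3) or transforms forces a fresh call.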