jax.extend.linear_util.cache

jax.extend.linear_util.cache(call)

Memoization decorator for functions that take a WrappedFun as their first argument.

Parameters:

call (Callable) – a Python callable that takes a WrappedFun as its first argument. The underlying transforms and params on the WrappedFun are used as part of the memoization cache key.

Returns:

A memoized version of call.
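The caching pattern described above can be illustrated with a small, dependency-free sketch. This is not JAX's implementation. `ToyWrappedFun` and `toy_cache` are hypothetical stand-ins that only mirror what this page states: the wrapped function's transforms and params, together with the remaining arguments, determine the cache key. The real `cache` may include additional state in its key. Storing one table per underlying function in a `weakref.WeakKeyDictionary` is a design choice made here so that cached entries do not keep that function alive.

```python
import weakref
from typing import Any, Callable, Hashable


class ToyWrappedFun:
    """Hypothetical stand-in for WrappedFun: a callable plus hashable
    records of the transforms and static params applied to it."""

    def __init__(self, f: Callable, transforms: tuple = (), params: tuple = ()):
        self.f = f
        self.transforms = transforms  # hashable record of applied transforms
        self.params = params          # hashable (name, value) pairs

    def call_wrapped(self, *args):
        # Invoke the underlying function with the stored params as kwargs.
        return self.f(*args, **dict(self.params))


def toy_cache(call: Callable) -> Callable:
    """Memoize call(fun, *args), keyed on fun's transforms, params and args."""
    # One table per underlying Python function; entries vanish when it is
    # garbage-collected (an assumption of this sketch, not documented here).
    per_function = weakref.WeakKeyDictionary()

    def memoized(fun: ToyWrappedFun, *args: Hashable) -> Any:
        table = per_function.setdefault(fun.f, {})
        key = (fun.transforms, fun.params, args)
        if key not in table:
            table[key] = call(fun, *args)  # cache miss: do the real work
        return table[key]

    return memoized


# Usage: count how often the decorated function actually runs.
calls = []


@toy_cache
def trace(fun: ToyWrappedFun, x):
    calls.append(x)  # recorded only on cache misses
    return fun.call_wrapped(x)


def scale(x, factor):
    return x * factor


a = ToyWrappedFun(scale, transforms=("toy_transform",), params=(("factor", 2),))
b = ToyWrappedFun(scale, transforms=("toy_transform",), params=(("factor", 3),))

print(trace(a, 5))  # 10 (computed)
print(trace(a, 5))  # 10 (served from the cache)
print(trace(b, 5))  # 15 (different params, so a new cache entry)
print(calls)        # [5, 5]
```

Because `a` and `b` wrap the same Python function but carry different params, they occupy separate cache entries, while repeating a call with identical transforms, params and arguments skips the work entirely.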